Skip to content

Commit

Permalink
Address PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
samin-cf committed Nov 25, 2024
1 parent 0e56247 commit fce8f21
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 43 deletions.
2 changes: 1 addition & 1 deletion src/async_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl WMIConnection {
&query_language,
&query,
WBEM_FLAG_BIDIRECTIONAL,
&self.ctx,
&self.ctx.0,
&p_sink_handle,
)?;
}
Expand Down
18 changes: 4 additions & 14 deletions src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::context::WMIContext;
use crate::utils::WMIResult;
use crate::WMIError;
use log::debug;
Expand All @@ -13,8 +14,7 @@ use windows::Win32::System::Com::{
};
use windows::Win32::System::Rpc::{RPC_C_AUTHN_WINNT, RPC_C_AUTHZ_NONE};
use windows::Win32::System::Wmi::{
IWbemContext, IWbemLocator, IWbemServices, WbemContext, WbemLocator,
WBEM_FLAG_CONNECT_USE_MAX_WAIT,
IWbemLocator, IWbemServices, WbemLocator, WBEM_FLAG_CONNECT_USE_MAX_WAIT,
};

/// A marker to indicate that the current thread was `CoInitialize`d.
Expand Down Expand Up @@ -127,7 +127,7 @@ fn _test_com_lib_not_send(_s: impl Send) {}
pub struct WMIConnection {
_com_con: COMLibrary,
pub svc: IWbemServices,
pub ctx: IWbemContext,
pub(crate) ctx: WMIContext,
}

/// A connection to the local WMI provider, which provides querying capabilities.
Expand All @@ -153,7 +153,7 @@ impl WMIConnection {
pub fn with_namespace_path(namespace_path: &str, com_lib: COMLibrary) -> WMIResult<Self> {
let loc = create_locator()?;
let svc = create_services(&loc, namespace_path)?;
let ctx = create_context()?;
let ctx = WMIContext::new()?;

let this = Self {
_com_con: com_lib,
Expand Down Expand Up @@ -195,16 +195,6 @@ fn create_locator() -> WMIResult<IWbemLocator> {
Ok(loc)
}

fn create_context() -> WMIResult<IWbemContext> {
debug!("Calling CoCreateInstance for CLSID_WbemContext");

let ctx = unsafe { CoCreateInstance(&WbemContext, None, CLSCTX_INPROC_SERVER)? };

debug!("Got context {:?}", ctx);

Ok(ctx)
}

fn create_services(loc: &IWbemLocator, path: &str) -> WMIResult<IWbemServices> {
debug!("Calling ConnectServer");

Expand Down
93 changes: 67 additions & 26 deletions src/context.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::collections::HashMap;

use serde::Serialize;
use windows_core::{BSTR, VARIANT};

use crate::{WMIConnection, WMIResult};
use log::debug;
use windows::Win32::System::{
Com::{CoCreateInstance, CLSCTX_INPROC_SERVER},
Wmi::{IWbemContext, WbemContext},
};
use windows_core::{BSTR, VARIANT};

#[derive(Debug, PartialEq, Serialize, Clone)]
#[serde(untagged)]
#[derive(Debug, PartialEq, Clone)]
#[non_exhaustive]
pub enum ContextValueType {
String(String),
I4(i32),
Expand All @@ -25,26 +26,43 @@ impl From<ContextValueType> for VARIANT {
}
}

impl WMIConnection {
/// Sets the specified named context values for use in providing additional context information to queries.
#[derive(Clone, Debug)]
pub struct WMIContext(pub(crate) IWbemContext);

impl WMIContext {
/// Creates a new instances of [`WMIContext`]
pub(crate) fn new() -> WMIResult<WMIContext> {
debug!("Calling CoCreateInstance for CLSID_WbemContext");

let ctx = unsafe { CoCreateInstance(&WbemContext, None, CLSCTX_INPROC_SERVER)? };

debug!("Got context {:?}", ctx);

Ok(WMIContext(ctx))
}

/// Sets the specified named context value for use in providing additional context information to queries.
///
/// Note the context values will persist across subsequent queries until [`WMIConnection::clear_ctx_values`] is called.
pub fn set_ctx_values(
&mut self,
ctx_values: HashMap<String, ContextValueType>,
) -> WMIResult<()> {
for (k, v) in ctx_values {
let key = BSTR::from(k);
let value = v.clone().into();
unsafe { self.ctx.SetValue(&key, 0, &value)? };
}
/// Note the context values will persist across subsequent queries until [`WMIConnection::delete_all`] is called.
pub fn set_value(&mut self, key: &str, value: impl Into<ContextValueType>) -> WMIResult<()> {
let value = value.into();
unsafe { self.0.SetValue(&BSTR::from(key), 0, &value.into())? };

Ok(())
}

/// Clears all named values from the underlying context object.
pub fn clear_ctx_values(&mut self) -> WMIResult<()> {
unsafe { self.ctx.DeleteAll().map_err(Into::into) }
pub fn delete_all(&mut self) -> WMIResult<()> {
unsafe { self.0.DeleteAll()? };

Ok(())
}
}

impl WMIConnection {
/// Returns a mutable reference to the [`WMIContext`] object
pub fn ctx(&mut self) -> &mut WMIContext {
&mut self.ctx
}
}

Expand Down Expand Up @@ -86,18 +104,41 @@ mod tests {
let mut orig_adapters = wmi_con.query::<MSFT_NetAdapter>().unwrap();
assert!(!orig_adapters.is_empty());

let mut ctx_values = HashMap::new();
ctx_values.insert("IncludeHidden".into(), true.into());
wmi_con.set_ctx_values(ctx_values).unwrap();

// With 'IncludeHidden' set to 'true', expect the response to contain additional adapters
wmi_con.ctx().set_value("IncludeHidden", true).unwrap();
let all_adapters = wmi_con.query::<MSFT_NetAdapter>().unwrap();
assert!(all_adapters.len() > orig_adapters.len());

wmi_con.clear_ctx_values().unwrap();
wmi_con.ctx().delete_all().unwrap();
let mut adapters = wmi_con.query::<MSFT_NetAdapter>().unwrap();
adapters.sort();
orig_adapters.sort();
assert_eq!(adapters, orig_adapters);
}

#[tokio::test]
async fn async_verify_ctx_values_used() {
let com_con = COMLibrary::new().unwrap();
let mut wmi_con =
WMIConnection::with_namespace_path("ROOT\\StandardCimv2", com_con).unwrap();

#[derive(Deserialize, PartialEq, Eq, PartialOrd, Ord, Debug)]
struct MSFT_NetAdapter {
InterfaceName: String,
}

let mut orig_adapters = wmi_con.async_query::<MSFT_NetAdapter>().await.unwrap();
assert!(!orig_adapters.is_empty());

// With 'IncludeHidden' set to 'true', expect the response to contain additional adapters
wmi_con.ctx().set_value("IncludeHidden", true).unwrap();
let all_adapters = wmi_con.async_query::<MSFT_NetAdapter>().await.unwrap();
assert!(all_adapters.len() > orig_adapters.len());

wmi_con.ctx().delete_all().unwrap();
let mut adapters = wmi_con.async_query::<MSFT_NetAdapter>().await.unwrap();
adapters.sort();
orig_adapters.sort();
assert_eq!(adapters, orig_adapters);
}
}
4 changes: 2 additions & 2 deletions src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ impl WMIConnection {
&query_language,
&query,
WBEM_FLAG_FORWARD_ONLY | WBEM_FLAG_RETURN_IMMEDIATELY,
&self.ctx,
&self.ctx.0,
)?
};

Expand Down Expand Up @@ -431,7 +431,7 @@ impl WMIConnection {
self.svc.GetObject(
&object_path,
WBEM_FLAG_RETURN_WBEM_COMPLETE,
None,
&self.ctx.0,
Some(&mut pcls_obj),
None,
)?;
Expand Down

0 comments on commit fce8f21

Please sign in to comment.