diff --git a/src/async_query.rs b/src/async_query.rs index d668484..99ed02d 100644 --- a/src/async_query.rs +++ b/src/async_query.rs @@ -40,7 +40,7 @@ impl WMIConnection { &query_language, &query, WBEM_FLAG_BIDIRECTIONAL, - &self.ctx, + &self.ctx.0, &p_sink_handle, )?; } diff --git a/src/connection.rs b/src/connection.rs index 9db5555..ff89c01 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,3 +1,4 @@ +use crate::context::WMIContext; use crate::utils::WMIResult; use crate::WMIError; use log::debug; @@ -13,8 +14,7 @@ use windows::Win32::System::Com::{ }; use windows::Win32::System::Rpc::{RPC_C_AUTHN_WINNT, RPC_C_AUTHZ_NONE}; use windows::Win32::System::Wmi::{ - IWbemContext, IWbemLocator, IWbemServices, WbemContext, WbemLocator, - WBEM_FLAG_CONNECT_USE_MAX_WAIT, + IWbemLocator, IWbemServices, WbemLocator, WBEM_FLAG_CONNECT_USE_MAX_WAIT, }; /// A marker to indicate that the current thread was `CoInitialize`d. @@ -127,7 +127,7 @@ fn _test_com_lib_not_send(_s: impl Send) {} pub struct WMIConnection { _com_con: COMLibrary, pub svc: IWbemServices, - pub ctx: IWbemContext, + pub(crate) ctx: WMIContext, } /// A connection to the local WMI provider, which provides querying capabilities. @@ -153,7 +153,7 @@ impl WMIConnection { pub fn with_namespace_path(namespace_path: &str, com_lib: COMLibrary) -> WMIResult { let loc = create_locator()?; let svc = create_services(&loc, namespace_path)?; - let ctx = create_context()?; + let ctx = WMIContext::new()?; let this = Self { _com_con: com_lib, @@ -195,16 +195,6 @@ fn create_locator() -> WMIResult { Ok(loc) } -fn create_context() -> WMIResult { - debug!("Calling CoCreateInstance for CLSID_WbemContext"); - - let ctx = unsafe { CoCreateInstance(&WbemContext, None, CLSCTX_INPROC_SERVER)? }; - - debug!("Got context {:?}", ctx); - - Ok(ctx) -} - fn create_services(loc: &IWbemLocator, path: &str) -> WMIResult { debug!("Calling ConnectServer"); diff --git a/src/context.rs b/src/context.rs index b275fd1..695ac40 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,12 +1,13 @@ -use std::collections::HashMap; - -use serde::Serialize; -use windows_core::{BSTR, VARIANT}; - use crate::{WMIConnection, WMIResult}; +use log::debug; +use windows::Win32::System::{ + Com::{CoCreateInstance, CLSCTX_INPROC_SERVER}, + Wmi::{IWbemContext, WbemContext}, +}; +use windows_core::{BSTR, VARIANT}; -#[derive(Debug, PartialEq, Serialize, Clone)] -#[serde(untagged)] +#[derive(Debug, PartialEq, Clone)] +#[non_exhaustive] pub enum ContextValueType { String(String), I4(i32), @@ -25,26 +26,43 @@ impl From for VARIANT { } } -impl WMIConnection { - /// Sets the specified named context values for use in providing additional context information to queries. +#[derive(Clone, Debug)] +pub struct WMIContext(pub(crate) IWbemContext); + +impl WMIContext { + /// Creates a new instances of [`WMIContext`] + pub(crate) fn new() -> WMIResult { + debug!("Calling CoCreateInstance for CLSID_WbemContext"); + + let ctx = unsafe { CoCreateInstance(&WbemContext, None, CLSCTX_INPROC_SERVER)? }; + + debug!("Got context {:?}", ctx); + + Ok(WMIContext(ctx)) + } + + /// Sets the specified named context value for use in providing additional context information to queries. /// - /// Note the context values will persist across subsequent queries until [`WMIConnection::clear_ctx_values`] is called. - pub fn set_ctx_values( - &mut self, - ctx_values: HashMap, - ) -> WMIResult<()> { - for (k, v) in ctx_values { - let key = BSTR::from(k); - let value = v.clone().into(); - unsafe { self.ctx.SetValue(&key, 0, &value)? }; - } + /// Note the context values will persist across subsequent queries until [`WMIConnection::delete_all`] is called. + pub fn set_value(&mut self, key: &str, value: impl Into) -> WMIResult<()> { + let value = value.into(); + unsafe { self.0.SetValue(&BSTR::from(key), 0, &value.into())? }; Ok(()) } /// Clears all named values from the underlying context object. - pub fn clear_ctx_values(&mut self) -> WMIResult<()> { - unsafe { self.ctx.DeleteAll().map_err(Into::into) } + pub fn delete_all(&mut self) -> WMIResult<()> { + unsafe { self.0.DeleteAll()? }; + + Ok(()) + } +} + +impl WMIConnection { + /// Returns a mutable reference to the [`WMIContext`] object + pub fn ctx(&mut self) -> &mut WMIContext { + &mut self.ctx } } @@ -86,18 +104,41 @@ mod tests { let mut orig_adapters = wmi_con.query::().unwrap(); assert!(!orig_adapters.is_empty()); - let mut ctx_values = HashMap::new(); - ctx_values.insert("IncludeHidden".into(), true.into()); - wmi_con.set_ctx_values(ctx_values).unwrap(); - // With 'IncludeHidden' set to 'true', expect the response to contain additional adapters + wmi_con.ctx().set_value("IncludeHidden", true).unwrap(); let all_adapters = wmi_con.query::().unwrap(); assert!(all_adapters.len() > orig_adapters.len()); - wmi_con.clear_ctx_values().unwrap(); + wmi_con.ctx().delete_all().unwrap(); let mut adapters = wmi_con.query::().unwrap(); adapters.sort(); orig_adapters.sort(); assert_eq!(adapters, orig_adapters); } + + #[tokio::test] + async fn async_verify_ctx_values_used() { + let com_con = COMLibrary::new().unwrap(); + let mut wmi_con = + WMIConnection::with_namespace_path("ROOT\\StandardCimv2", com_con).unwrap(); + + #[derive(Deserialize, PartialEq, Eq, PartialOrd, Ord, Debug)] + struct MSFT_NetAdapter { + InterfaceName: String, + } + + let mut orig_adapters = wmi_con.async_query::().await.unwrap(); + assert!(!orig_adapters.is_empty()); + + // With 'IncludeHidden' set to 'true', expect the response to contain additional adapters + wmi_con.ctx().set_value("IncludeHidden", true).unwrap(); + let all_adapters = wmi_con.async_query::().await.unwrap(); + assert!(all_adapters.len() > orig_adapters.len()); + + wmi_con.ctx().delete_all().unwrap(); + let mut adapters = wmi_con.async_query::().await.unwrap(); + adapters.sort(); + orig_adapters.sort(); + assert_eq!(adapters, orig_adapters); + } } diff --git a/src/query.rs b/src/query.rs index 4b2d747..cebf31d 100644 --- a/src/query.rs +++ b/src/query.rs @@ -279,7 +279,7 @@ impl WMIConnection { &query_language, &query, WBEM_FLAG_FORWARD_ONLY | WBEM_FLAG_RETURN_IMMEDIATELY, - &self.ctx, + &self.ctx.0, )? }; @@ -431,7 +431,7 @@ impl WMIConnection { self.svc.GetObject( &object_path, WBEM_FLAG_RETURN_WBEM_COMPLETE, - None, + &self.ctx.0, Some(&mut pcls_obj), None, )?;