From 0e56247a6cf67c41af7785b1f007f6271e8fe040 Mon Sep 17 00:00:00 2001 From: Shrey Amin Date: Sat, 23 Nov 2024 14:07:17 -0500 Subject: [PATCH 1/2] Add support for setting IWbemContext Some queries require setting named context values using the IWbemContext interface. For example, MSFT_NetFirewallProfile returns the firewall status as configured by the local policy. To get the effective policy based on the local + group policies, one must specify the 'PolicyStore' as 'ActiveStore' using the context interface. --- src/async_query.rs | 2 +- src/connection.rs | 16 ++++++- src/context.rs | 103 +++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + src/query.rs | 4 +- 5 files changed, 122 insertions(+), 4 deletions(-) create mode 100644 src/context.rs diff --git a/src/async_query.rs b/src/async_query.rs index f05d59c..d668484 100644 --- a/src/async_query.rs +++ b/src/async_query.rs @@ -40,7 +40,7 @@ impl WMIConnection { &query_language, &query, WBEM_FLAG_BIDIRECTIONAL, - None, + &self.ctx, &p_sink_handle, )?; } diff --git a/src/connection.rs b/src/connection.rs index 8d907c6..9db5555 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -13,7 +13,8 @@ use windows::Win32::System::Com::{ }; use windows::Win32::System::Rpc::{RPC_C_AUTHN_WINNT, RPC_C_AUTHZ_NONE}; use windows::Win32::System::Wmi::{ - IWbemLocator, IWbemServices, WbemLocator, WBEM_FLAG_CONNECT_USE_MAX_WAIT, + IWbemContext, IWbemLocator, IWbemServices, WbemContext, WbemLocator, + WBEM_FLAG_CONNECT_USE_MAX_WAIT, }; /// A marker to indicate that the current thread was `CoInitialize`d. @@ -126,6 +127,7 @@ fn _test_com_lib_not_send(_s: impl Send) {} pub struct WMIConnection { _com_con: COMLibrary, pub svc: IWbemServices, + pub ctx: IWbemContext, } /// A connection to the local WMI provider, which provides querying capabilities. @@ -151,10 +153,12 @@ impl WMIConnection { pub fn with_namespace_path(namespace_path: &str, com_lib: COMLibrary) -> WMIResult { let loc = create_locator()?; let svc = create_services(&loc, namespace_path)?; + let ctx = create_context()?; let this = Self { _com_con: com_lib, svc, + ctx, }; this.set_proxy()?; @@ -191,6 +195,16 @@ fn create_locator() -> WMIResult { Ok(loc) } +fn create_context() -> WMIResult { + debug!("Calling CoCreateInstance for CLSID_WbemContext"); + + let ctx = unsafe { CoCreateInstance(&WbemContext, None, CLSCTX_INPROC_SERVER)? }; + + debug!("Got context {:?}", ctx); + + Ok(ctx) +} + fn create_services(loc: &IWbemLocator, path: &str) -> WMIResult { debug!("Calling ConnectServer"); diff --git a/src/context.rs b/src/context.rs new file mode 100644 index 0000000..b275fd1 --- /dev/null +++ b/src/context.rs @@ -0,0 +1,103 @@ +use std::collections::HashMap; + +use serde::Serialize; +use windows_core::{BSTR, VARIANT}; + +use crate::{WMIConnection, WMIResult}; + +#[derive(Debug, PartialEq, Serialize, Clone)] +#[serde(untagged)] +pub enum ContextValueType { + String(String), + I4(i32), + R8(f64), + Bool(bool), +} + +impl From for VARIANT { + fn from(value: ContextValueType) -> Self { + match value { + ContextValueType::Bool(b) => Self::from(b), + ContextValueType::I4(i4) => Self::from(i4), + ContextValueType::R8(r8) => Self::from(r8), + ContextValueType::String(str) => Self::from(BSTR::from(str)), + } + } +} + +impl WMIConnection { + /// Sets the specified named context values for use in providing additional context information to queries. + /// + /// Note the context values will persist across subsequent queries until [`WMIConnection::clear_ctx_values`] is called. + pub fn set_ctx_values( + &mut self, + ctx_values: HashMap, + ) -> WMIResult<()> { + for (k, v) in ctx_values { + let key = BSTR::from(k); + let value = v.clone().into(); + unsafe { self.ctx.SetValue(&key, 0, &value)? }; + } + + Ok(()) + } + + /// Clears all named values from the underlying context object. + pub fn clear_ctx_values(&mut self) -> WMIResult<()> { + unsafe { self.ctx.DeleteAll().map_err(Into::into) } + } +} + +macro_rules! impl_from_type { + ($target_type:ty, $variant:ident) => { + impl From<$target_type> for ContextValueType { + fn from(value: $target_type) -> Self { + Self::$variant(value.into()) + } + } + }; +} + +impl_from_type!(&str, String); +impl_from_type!(i32, I4); +impl_from_type!(f64, R8); +impl_from_type!(bool, Bool); + +#[allow(non_snake_case)] +#[allow(non_camel_case_types)] +#[allow(dead_code)] +#[cfg(test)] +mod tests { + use super::*; + use crate::COMLibrary; + use serde::Deserialize; + + #[test] + fn verify_ctx_values_used() { + let com_con = COMLibrary::new().unwrap(); + let mut wmi_con = + WMIConnection::with_namespace_path("ROOT\\StandardCimv2", com_con).unwrap(); + + #[derive(Deserialize, PartialEq, Eq, PartialOrd, Ord, Debug)] + struct MSFT_NetAdapter { + InterfaceName: String, + } + + let mut orig_adapters = wmi_con.query::().unwrap(); + assert!(!orig_adapters.is_empty()); + + let mut ctx_values = HashMap::new(); + ctx_values.insert("IncludeHidden".into(), true.into()); + wmi_con.set_ctx_values(ctx_values).unwrap(); + + // With 'IncludeHidden' set to 'true', expect the response to contain additional adapters + let all_adapters = wmi_con.query::().unwrap(); + assert!(all_adapters.len() > orig_adapters.len()); + + wmi_con.clear_ctx_values().unwrap(); + let mut adapters = wmi_con.query::().unwrap(); + adapters.sort(); + orig_adapters.sort(); + assert_eq!(adapters, orig_adapters); + } +} diff --git a/src/lib.rs b/src/lib.rs index 0beade0..82b13d3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -273,6 +273,7 @@ pub mod datetime; #[cfg(feature = "time")] mod datetime_time; +pub mod context; pub mod de; pub mod duration; pub mod query; diff --git a/src/query.rs b/src/query.rs index a93578e..4b2d747 100644 --- a/src/query.rs +++ b/src/query.rs @@ -279,7 +279,7 @@ impl WMIConnection { &query_language, &query, WBEM_FLAG_FORWARD_ONLY | WBEM_FLAG_RETURN_IMMEDIATELY, - None, + &self.ctx, )? }; @@ -536,7 +536,7 @@ impl WMIConnection { /// Query all the associators of type T of the given object. /// The `object_path` argument can be provided by querying an object wih it's `__Path` property. - /// `AssocClass` must be have the name as the conneting association class between the original object and the results. + /// `AssocClass` must be have the name as the connecting association class between the original object and the results. /// See for example. /// /// ```edition2018 From b0a933130531092014f11c45c96c494c60384060 Mon Sep 17 00:00:00 2001 From: Shrey Amin Date: Mon, 25 Nov 2024 15:08:09 -0500 Subject: [PATCH 2/2] Address PR feedback --- src/async_query.rs | 2 +- src/connection.rs | 18 ++------- src/context.rs | 93 +++++++++++++++++++++++++++++++++------------- src/query.rs | 4 +- 4 files changed, 74 insertions(+), 43 deletions(-) diff --git a/src/async_query.rs b/src/async_query.rs index d668484..99ed02d 100644 --- a/src/async_query.rs +++ b/src/async_query.rs @@ -40,7 +40,7 @@ impl WMIConnection { &query_language, &query, WBEM_FLAG_BIDIRECTIONAL, - &self.ctx, + &self.ctx.0, &p_sink_handle, )?; } diff --git a/src/connection.rs b/src/connection.rs index 9db5555..ff89c01 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,3 +1,4 @@ +use crate::context::WMIContext; use crate::utils::WMIResult; use crate::WMIError; use log::debug; @@ -13,8 +14,7 @@ use windows::Win32::System::Com::{ }; use windows::Win32::System::Rpc::{RPC_C_AUTHN_WINNT, RPC_C_AUTHZ_NONE}; use windows::Win32::System::Wmi::{ - IWbemContext, IWbemLocator, IWbemServices, WbemContext, WbemLocator, - WBEM_FLAG_CONNECT_USE_MAX_WAIT, + IWbemLocator, IWbemServices, WbemLocator, WBEM_FLAG_CONNECT_USE_MAX_WAIT, }; /// A marker to indicate that the current thread was `CoInitialize`d. @@ -127,7 +127,7 @@ fn _test_com_lib_not_send(_s: impl Send) {} pub struct WMIConnection { _com_con: COMLibrary, pub svc: IWbemServices, - pub ctx: IWbemContext, + pub(crate) ctx: WMIContext, } /// A connection to the local WMI provider, which provides querying capabilities. @@ -153,7 +153,7 @@ impl WMIConnection { pub fn with_namespace_path(namespace_path: &str, com_lib: COMLibrary) -> WMIResult { let loc = create_locator()?; let svc = create_services(&loc, namespace_path)?; - let ctx = create_context()?; + let ctx = WMIContext::new()?; let this = Self { _com_con: com_lib, @@ -195,16 +195,6 @@ fn create_locator() -> WMIResult { Ok(loc) } -fn create_context() -> WMIResult { - debug!("Calling CoCreateInstance for CLSID_WbemContext"); - - let ctx = unsafe { CoCreateInstance(&WbemContext, None, CLSCTX_INPROC_SERVER)? }; - - debug!("Got context {:?}", ctx); - - Ok(ctx) -} - fn create_services(loc: &IWbemLocator, path: &str) -> WMIResult { debug!("Calling ConnectServer"); diff --git a/src/context.rs b/src/context.rs index b275fd1..f89ce0c 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,12 +1,13 @@ -use std::collections::HashMap; - -use serde::Serialize; -use windows_core::{BSTR, VARIANT}; - use crate::{WMIConnection, WMIResult}; +use log::debug; +use windows::Win32::System::{ + Com::{CoCreateInstance, CLSCTX_INPROC_SERVER}, + Wmi::{IWbemContext, WbemContext}, +}; +use windows_core::{BSTR, VARIANT}; -#[derive(Debug, PartialEq, Serialize, Clone)] -#[serde(untagged)] +#[derive(Debug, Clone)] +#[non_exhaustive] pub enum ContextValueType { String(String), I4(i32), @@ -25,26 +26,43 @@ impl From for VARIANT { } } -impl WMIConnection { - /// Sets the specified named context values for use in providing additional context information to queries. +#[derive(Clone, Debug)] +pub struct WMIContext(pub(crate) IWbemContext); + +impl WMIContext { + /// Creates a new instances of [`WMIContext`] + pub(crate) fn new() -> WMIResult { + debug!("Calling CoCreateInstance for CLSID_WbemContext"); + + let ctx = unsafe { CoCreateInstance(&WbemContext, None, CLSCTX_INPROC_SERVER)? }; + + debug!("Got context {:?}", ctx); + + Ok(WMIContext(ctx)) + } + + /// Sets the specified named context value for use in providing additional context information to queries. /// - /// Note the context values will persist across subsequent queries until [`WMIConnection::clear_ctx_values`] is called. - pub fn set_ctx_values( - &mut self, - ctx_values: HashMap, - ) -> WMIResult<()> { - for (k, v) in ctx_values { - let key = BSTR::from(k); - let value = v.clone().into(); - unsafe { self.ctx.SetValue(&key, 0, &value)? }; - } + /// Note the context values will persist across subsequent queries until [`WMIConnection::delete_all`] is called. + pub fn set_value(&mut self, key: &str, value: impl Into) -> WMIResult<()> { + let value = value.into(); + unsafe { self.0.SetValue(&BSTR::from(key), 0, &value.into())? }; Ok(()) } /// Clears all named values from the underlying context object. - pub fn clear_ctx_values(&mut self) -> WMIResult<()> { - unsafe { self.ctx.DeleteAll().map_err(Into::into) } + pub fn delete_all(&mut self) -> WMIResult<()> { + unsafe { self.0.DeleteAll()? }; + + Ok(()) + } +} + +impl WMIConnection { + /// Returns a mutable reference to the [`WMIContext`] object + pub fn ctx(&mut self) -> &mut WMIContext { + &mut self.ctx } } @@ -86,18 +104,41 @@ mod tests { let mut orig_adapters = wmi_con.query::().unwrap(); assert!(!orig_adapters.is_empty()); - let mut ctx_values = HashMap::new(); - ctx_values.insert("IncludeHidden".into(), true.into()); - wmi_con.set_ctx_values(ctx_values).unwrap(); - // With 'IncludeHidden' set to 'true', expect the response to contain additional adapters + wmi_con.ctx().set_value("IncludeHidden", true).unwrap(); let all_adapters = wmi_con.query::().unwrap(); assert!(all_adapters.len() > orig_adapters.len()); - wmi_con.clear_ctx_values().unwrap(); + wmi_con.ctx().delete_all().unwrap(); let mut adapters = wmi_con.query::().unwrap(); adapters.sort(); orig_adapters.sort(); assert_eq!(adapters, orig_adapters); } + + #[tokio::test] + async fn async_verify_ctx_values_used() { + let com_con = COMLibrary::new().unwrap(); + let mut wmi_con = + WMIConnection::with_namespace_path("ROOT\\StandardCimv2", com_con).unwrap(); + + #[derive(Deserialize, PartialEq, Eq, PartialOrd, Ord, Debug)] + struct MSFT_NetAdapter { + InterfaceName: String, + } + + let mut orig_adapters = wmi_con.async_query::().await.unwrap(); + assert!(!orig_adapters.is_empty()); + + // With 'IncludeHidden' set to 'true', expect the response to contain additional adapters + wmi_con.ctx().set_value("IncludeHidden", true).unwrap(); + let all_adapters = wmi_con.async_query::().await.unwrap(); + assert!(all_adapters.len() > orig_adapters.len()); + + wmi_con.ctx().delete_all().unwrap(); + let mut adapters = wmi_con.async_query::().await.unwrap(); + adapters.sort(); + orig_adapters.sort(); + assert_eq!(adapters, orig_adapters); + } } diff --git a/src/query.rs b/src/query.rs index 4b2d747..cebf31d 100644 --- a/src/query.rs +++ b/src/query.rs @@ -279,7 +279,7 @@ impl WMIConnection { &query_language, &query, WBEM_FLAG_FORWARD_ONLY | WBEM_FLAG_RETURN_IMMEDIATELY, - &self.ctx, + &self.ctx.0, )? }; @@ -431,7 +431,7 @@ impl WMIConnection { self.svc.GetObject( &object_path, WBEM_FLAG_RETURN_WBEM_COMPLETE, - None, + &self.ctx.0, Some(&mut pcls_obj), None, )?;