Skip to content

Commit

Permalink
Merge pull request #66 from JustRustThings/fix-cancel-async-notification
Browse files Browse the repository at this point in the history
feat: properly handle cancellation of async notification
  • Loading branch information
ohadravid authored Feb 28, 2023
2 parents fbcd53c + 5dc906f commit d1e5c01
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 95 deletions.
6 changes: 3 additions & 3 deletions src/async_query.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
query_sink::{AsyncQueryResultStream, IWbemObjectSink, QuerySink},
query_sink::{AsyncQueryResultStream, IWbemObjectSink, QuerySink, AsyncQueryResultStreamInner},
query::{FilterValue, build_query},
result_enumerator::IWbemClassWrapper,
connection::WMIConnection,
Expand Down Expand Up @@ -28,7 +28,7 @@ impl WMIConnection {
let query_language = BStr::from_str("WQL")?;
let query = BStr::from_str(query.as_ref())?;

let stream = AsyncQueryResultStream::new();
let stream = AsyncQueryResultStreamInner::new();
// The internal RefCount has initial value = 1.
let p_sink: ClassAllocation<QuerySink> = QuerySink::allocate(stream.clone());
let p_sink_handle = IWbemObjectSink::from(&**p_sink);
Expand All @@ -45,7 +45,7 @@ impl WMIConnection {
))?;
}

Ok(stream)
Ok(AsyncQueryResultStream::new(stream, self.clone(), p_sink))
}

/// Async version of [`raw_query`](WMIConnection#method.raw_query)
Expand Down
162 changes: 84 additions & 78 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,25 +114,14 @@ fn _test_com_lib_not_send(_s: impl Send) {}

pub struct WMIConnection {
_com_con: COMLibrary,
p_loc: Option<NonNull<IWbemLocator>>,
p_svc: Option<NonNull<IWbemServices>>,
p_svc: NonNull<IWbemServices>,
}

/// A connection to the local WMI provider, which provides querying capabilities.
///
/// Currently does not support remote providers (e.g connecting to other computers).
///
impl WMIConnection {
fn create_and_set_proxy(&mut self, namespace_path: Option<&str>) -> WMIResult<()> {
self.create_locator()?;

self.create_services(namespace_path.unwrap_or("ROOT\\CIMV2"))?;

self.set_proxy()?;

Ok(())
}

/// Creates a connection with a default `CIMV2` namespace path.
pub fn new(com_lib: COMLibrary) -> WMIResult<Self> {
Self::with_namespace_path("ROOT\\CIMV2", com_lib)
Expand All @@ -152,106 +141,123 @@ impl WMIConnection {
namespace_path: &str,
com_lib: COMLibrary,
) -> WMIResult<Self> {
let mut instance = Self {
let p_loc = create_locator()?;
let p_svc = create_services(p_loc.0.as_ptr(), namespace_path)?;

let this = Self {
_com_con: com_lib,
p_loc: None,
p_svc: None,
p_svc,
};

instance.create_and_set_proxy(Some(namespace_path))?;

Ok(instance)
this.set_proxy()?;
Ok(this)
}

pub fn svc(&self) -> *mut IWbemServices {
self.p_svc.unwrap().as_ptr()
self.p_svc.as_ptr()
}

fn loc(&self) -> *mut IWbemLocator {
self.p_loc.unwrap().as_ptr()
}

fn create_locator(&mut self) -> WMIResult<()> {
debug!("Calling CoCreateInstance for CLSID_WbemLocator");

let mut p_loc = NULL;
fn set_proxy(&self) -> WMIResult<()> {
debug!("Calling CoSetProxyBlanket");

unsafe {
check_hres(CoCreateInstance(
&CLSID_WbemLocator,
ptr::null_mut(),
CLSCTX_INPROC_SERVER,
&IID_IWbemLocator,
&mut p_loc,
check_hres(CoSetProxyBlanket(
self.svc() as _, // Indicates the proxy to set
RPC_C_AUTHN_WINNT, // RPC_C_AUTHN_xxx
RPC_C_AUTHZ_NONE, // RPC_C_AUTHZ_xxx
ptr::null_mut(), // Server principal name
RPC_C_AUTHN_LEVEL_CALL, // RPC_C_AUTHN_LEVEL_xxx
RPC_C_IMP_LEVEL_IMPERSONATE, // RPC_C_IMP_LEVEL_xxx
NULL, // client identity
EOAC_NONE, // proxy capabilities
))?;
}

self.p_loc = NonNull::new(p_loc as *mut IWbemLocator);
Ok(())
}
}

fn create_locator() -> WMIResult<WbemLocator> {
debug!("Calling CoCreateInstance for CLSID_WbemLocator");

debug!("Got locator {:?}", self.p_loc);
let mut p_loc = NULL;

Ok(())
unsafe {
check_hres(CoCreateInstance(
&CLSID_WbemLocator,
ptr::null_mut(),
CLSCTX_INPROC_SERVER,
&IID_IWbemLocator,
&mut p_loc,
))?;
}

fn create_services(&mut self, path: &str) -> WMIResult<()> {
debug!("Calling ConnectServer");
let p_loc = NonNull::new(p_loc as *mut IWbemLocator).unwrap();

let mut p_svc = ptr::null_mut::<IWbemServices>();
debug!("Got locator {:?}", p_loc);

let object_path_bstr = BStr::from_str(path)?;
Ok(WbemLocator(p_loc))
}

unsafe {
check_hres((*self.loc()).ConnectServer(
object_path_bstr.as_bstr(),
ptr::null_mut(),
ptr::null_mut(),
ptr::null_mut(),
WBEM_FLAG_CONNECT_USE_MAX_WAIT as _,
ptr::null_mut(),
ptr::null_mut(),
&mut p_svc,
))?;
}
fn create_services(loc: *const IWbemLocator, path: &str) -> WMIResult<NonNull<IWbemServices>> {
debug!("Calling ConnectServer");

self.p_svc = NonNull::new(p_svc as *mut IWbemServices);
let mut p_svc = ptr::null_mut::<IWbemServices>();

debug!("Got service {:?}", self.p_svc);
let object_path_bstr = BStr::from_str(path)?;

Ok(())
unsafe {
check_hres((*loc).ConnectServer(
object_path_bstr.as_bstr(),
ptr::null_mut(),
ptr::null_mut(),
ptr::null_mut(),
WBEM_FLAG_CONNECT_USE_MAX_WAIT as _,
ptr::null_mut(),
ptr::null_mut(),
&mut p_svc,
))?;
}

fn set_proxy(&self) -> WMIResult<()> {
debug!("Calling CoSetProxyBlanket");
let p_svc = NonNull::new(p_svc as *mut IWbemServices).unwrap();

unsafe {
check_hres(CoSetProxyBlanket(
self.svc() as _, // Indicates the proxy to set
RPC_C_AUTHN_WINNT, // RPC_C_AUTHN_xxx
RPC_C_AUTHZ_NONE, // RPC_C_AUTHZ_xxx
ptr::null_mut(), // Server principal name
RPC_C_AUTHN_LEVEL_CALL, // RPC_C_AUTHN_LEVEL_xxx
RPC_C_IMP_LEVEL_IMPERSONATE, // RPC_C_IMP_LEVEL_xxx
NULL, // client identity
EOAC_NONE, // proxy capabilities
))?;
}
debug!("Got service {:?}", p_svc);

Ok(())
Ok(p_svc)
}

impl Clone for WMIConnection {
fn clone(&self) -> Self {
// Creates a copy of the pointer and calls
// [AddRef](https://docs.microsoft.com/en-us/windows/win32/api/unknwn/nf-unknwn-iunknown-addref)
// to increment Reference Count.
//
// # Safety
// See [Managing the lifetime of an object](https://docs.microsoft.com/en-us/windows/win32/learnwin32/managing-the-lifetime-of-an-object)
// and [Rules for managing Ref count](https://docs.microsoft.com/en-us/windows/win32/com/rules-for-managing-reference-counts)
unsafe { self.p_svc.as_ref().AddRef() };

Self {
_com_con: self._com_con,
p_svc: self.p_svc,
}
}
}

impl Drop for WMIConnection {
fn drop(&mut self) {
if let Some(svc) = self.p_svc {
unsafe {
(*svc.as_ptr()).Release();
}
unsafe {
(*self.p_svc.as_ptr()).Release();
}
}
}

struct WbemLocator(NonNull<IWbemLocator>);

if let Some(loc) = self.p_loc {
unsafe {
(*loc.as_ptr()).Release();
}
impl Drop for WbemLocator {
fn drop(&mut self) {
unsafe {
(*self.0.as_ptr()).Release();
}
}
}
6 changes: 3 additions & 3 deletions src/notification.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
query_sink::{AsyncQueryResultStream, QuerySink, IWbemObjectSink},
query_sink::{AsyncQueryResultStream, QuerySink, IWbemObjectSink, AsyncQueryResultStreamInner},
result_enumerator::{QueryResultEnumerator, IWbemClassWrapper},
bstr::BStr,
utils::check_hres,
Expand Down Expand Up @@ -153,7 +153,7 @@ impl WMIConnection {
let query_language = BStr::from_str("WQL")?;
let query = BStr::from_str(query.as_ref())?;

let stream = AsyncQueryResultStream::new();
let stream = AsyncQueryResultStreamInner::new();
// The internal RefCount has initial value = 1.
let p_sink: ClassAllocation<QuerySink> = QuerySink::allocate(stream.clone());
let p_sink_handle = IWbemObjectSink::from(&**p_sink);
Expand All @@ -170,7 +170,7 @@ impl WMIConnection {
))?;
}

Ok(stream)
Ok(AsyncQueryResultStream::new(stream, self.clone(), p_sink))
}

/// Async version of [`raw_notification`](WMIConnection#method.raw_notification)
Expand Down
Loading

0 comments on commit d1e5c01

Please sign in to comment.