Skip to content

Commit

Permalink
Store RDP licenses in a LRU cache instead of a HashMap
Browse files Browse the repository at this point in the history
This ensures memory doesn't grow unbounded if we end up caching
a large number of licenses.
  • Loading branch information
zmb3 committed Oct 30, 2024
1 parent a553aa0 commit 6ded4b8
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 21 deletions.
41 changes: 40 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions lib/srv/desktop/rdp/rdpclient/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ rdp-rs = { git = "https://github.com/gravitational/rdp-rs", rev = "2b0d99cc60c7b
uuid = { version = "1.4.1", features = ["v4"] }
utf16string = "0.2.0"
png = "0.17.10"
lru = "0.12.5"

[build-dependencies]
cbindgen = "0.25.0"
Expand Down
65 changes: 45 additions & 20 deletions lib/srv/desktop/rdp/rdpclient/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,13 @@ extern crate log;
extern crate num_derive;

use errors::try_error;
use lru::LruCache;
use rand::Rng;
use rand::SeedableRng;
use rdp::core::event::*;
use rdp::core::gcc::KeyboardLayout;
use rdp::core::global;
use rdp::core::global::ServerError;
use rdp::core::license::MemoryLicenseStore;
use rdp::core::mcs;
use rdp::core::sec;
use rdp::core::tpkt;
Expand All @@ -74,6 +74,7 @@ use std::io::ErrorKind;
use std::io::{Cursor, Read, Write};
use std::net;
use std::net::{TcpStream, ToSocketAddrs};
use std::num::NonZeroUsize;
use std::os::raw::c_char;
use std::os::unix::io::AsRawFd;
use std::sync::OnceLock;
Expand Down Expand Up @@ -260,7 +261,7 @@ struct ConnectParams {
show_desktop_wallpaper: bool,
}

static LICENSE_STORE: OnceLock<SyncLicenseStore<MemoryLicenseStore>> = OnceLock::new();
static LICENSE_STORE: OnceLock<LruLicenseStore> = OnceLock::new();

fn connect_rdp_inner(go_ref: usize, params: ConnectParams) -> Result<Client, ConnectError> {
// Connect and authenticate.
Expand Down Expand Up @@ -312,8 +313,7 @@ fn connect_rdp_inner(go_ref: usize, params: ConnectParams) -> Result<Client, Con
performance_flags |= sec::ExtendedInfoFlag::PerfDisableWallpaper as u32;
}

let license_store =
LICENSE_STORE.get_or_init(|| SyncLicenseStore::new(MemoryLicenseStore::new()));
let license_store = LICENSE_STORE.get_or_init(|| LruLicenseStore::new());

sec::connect(
&mut mcs,
Expand Down Expand Up @@ -2063,22 +2063,24 @@ pub struct CGOSharedDirectoryListRequest {
pub path: *const c_char,
}

/// SyncLicenseStore protects an underlying LicenseStore with a mutex,
/// making it safe for concurrent use.
#[derive(Default)]
struct SyncLicenseStore<L> {
inner: Mutex<L>,
const MAX_SAVED_LICENSES: usize = 256;

/// LruLicenseStore stores licenses in an in-memory LRU cache.
struct LruLicenseStore {
licenses: Mutex<LruCache<LicenseStoreKey, Vec<u8>>>,
}

impl<L> SyncLicenseStore<L> {
pub fn new(license_store: L) -> Self {
impl LruLicenseStore {
pub fn new() -> Self {
Self {
inner: Mutex::new(license_store),
licenses: Mutex::new(LruCache::new(
NonZeroUsize::new(MAX_SAVED_LICENSES).unwrap(),
)),
}
}
}

impl<L: LicenseStore> LicenseStore for &SyncLicenseStore<L> {
impl LicenseStore for &LruLicenseStore {
fn write_license(
&mut self,
major: u16,
Expand All @@ -2089,10 +2091,16 @@ impl<L: LicenseStore> LicenseStore for &SyncLicenseStore<L> {
license: &[u8],
) {
info!("Saving {major}.{minor} license from {issuer}");
self.inner
.lock()
.unwrap()
.write_license(major, minor, company, issuer, product_id, license);
self.licenses.lock().unwrap().put(
LicenseStoreKey {
major,
minor,
company: company.to_owned(),
issuer: issuer.to_owned(),
product_id: product_id.to_owned(),
},
license.to_vec(),
);
}

fn read_license(
Expand All @@ -2104,20 +2112,37 @@ impl<L: LicenseStore> LicenseStore for &SyncLicenseStore<L> {
product_id: &str,
) -> Option<Vec<u8>> {
let license = self
.inner
.licenses
.lock()
.unwrap()
.read_license(major, minor, company, issuer, product_id);
.get(&LicenseStoreKey {
major,
minor,
company: company.to_owned(),
issuer: issuer.to_owned(),
product_id: product_id.to_owned(),
})
.cloned();

if license.is_some() {
info!("Found existing {major}.{minor} license from {issuer}");
} else {
info!("No existing {major}.{minor} license from {issuer}");
}

license
return license;
}
}

#[derive(PartialEq, Eq, Hash)]
struct LicenseStoreKey {
major: u16,
minor: u16,
company: String,
issuer: String,
product_id: String,
}

// These functions are defined on the Go side. Look for functions with '//export funcname'
// comments.
extern "C" {
Expand Down

0 comments on commit 6ded4b8

Please sign in to comment.