Skip to content

Commit

Permalink
Fix the RDP licensing flow on v14
Browse files Browse the repository at this point in the history
Licenses are stored in-memory, so if the agent restarts it will be
forced to go through the "new license" flow for the first session.
  • Loading branch information
zmb3 committed Oct 24, 2024
1 parent f2d2c04 commit 667af23
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 20 deletions.
17 changes: 3 additions & 14 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion lib/srv/desktop/rdp/rdpclient/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ num-traits = "0.2.16"
rand = { version = "0.8.5", features = ["getrandom"] }
rand_chacha = "0.3.1"
rsa = "0.9.2"
rdp-rs = { git = "https://github.com/gravitational/rdp-rs", rev = "edfb5330a11d11eaf36d65e4300555368b4c6b02" }
#rdp-rs = { git = "https://github.com/gravitational/rdp-rs", rev = "edfb5330a11d11eaf36d65e4300555368b4c6b02" }
rdp-rs = { path = "/Users/zmb/src/rdp-rs" }
uuid = { version = "1.4.1", features = ["v4"] }
utf16string = "0.2.0"
png = "0.17.10"
Expand Down
9 changes: 6 additions & 3 deletions lib/srv/desktop/rdp/rdpclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,14 @@ func (c *Client) connect(ctx context.Context) error {
return trace.Wrap(err)
}

// Addr and username strings only need to be valid for the duration of
// These strings only need to be valid for the duration of
// C.connect_rdp. They are copied on the Rust side and can be freed here.
addr := C.CString(c.cfg.Addr)
defer C.free(unsafe.Pointer(addr))
username := C.CString(c.username)
defer C.free(unsafe.Pointer(username))
hostID := C.CString(c.cfg.HostID)
defer C.free(unsafe.Pointer(hostID))

cert_der, err := utils.UnsafeSliceData(userCertDER)
if err != nil {
Expand All @@ -261,8 +263,9 @@ func (c *Client) connect(ctx context.Context) error {
res := C.connect_rdp(
C.uintptr_t(c.handle),
C.CGOConnectParams{
go_addr: addr,
go_username: username,
go_addr: addr,
go_username: username,
go_client_id: hostID,
// cert length and bytes.
cert_der_len: C.uint32_t(len(userCertDER)),
cert_der: (*C.uint8_t)(cert_der),
Expand Down
6 changes: 5 additions & 1 deletion lib/srv/desktop/rdp/rdpclient/client_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,14 @@ import (
type Config struct {
// Addr is the network address of the RDP server, in the form host:port.
Addr string
// UserCertGenerator generates user certificates for RDP authentication.

// GenerateUserCert generates user certificates for RDP authentication.
GenerateUserCert GenerateUserCertFn
CertTTL time.Duration

// HostID uniquely identifies the Teleport agent running the RDP client.
HostID string

// AuthorizeFn is called to authorize a user connecting to a Windows desktop.
AuthorizeFn func(login string) error

Expand Down
67 changes: 66 additions & 1 deletion lib/srv/desktop/rdp/rdpclient/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ use rdp::core::event::*;
use rdp::core::gcc::KeyboardLayout;
use rdp::core::global;
use rdp::core::global::ServerError;
use rdp::core::license::MemoryLicenseStore;
use rdp::core::mcs;
use rdp::core::sec;
use rdp::core::tpkt;
use rdp::core::x224;
use rdp::core::LicenseStore;
use rdp::model::error::{Error as RdpError, RdpError as RdpProtocolError, RdpErrorKind, RdpResult};
use rdp::model::link::{Link, Stream};
use rdpdr::path::UnixPath;
Expand All @@ -74,6 +76,7 @@ use std::net;
use std::net::{TcpStream, ToSocketAddrs};
use std::os::raw::c_char;
use std::os::unix::io::AsRawFd;
use std::sync::OnceLock;
use std::sync::{Arc, Mutex};
use std::{mem, ptr, slice, time};

Expand Down Expand Up @@ -182,6 +185,7 @@ pub unsafe extern "C" fn connect_rdp(go_ref: usize, params: CGOConnectParams) ->
// Convert from C to Rust types.
let addr = from_c_string(params.go_addr);
let username = from_c_string(params.go_username);
let client_id = from_c_string(params.go_client_id);
let cert_der = from_go_array(params.cert_der, params.cert_der_len);
let key_der = from_go_array(params.key_der, params.key_der_len);

Expand All @@ -190,6 +194,7 @@ pub unsafe extern "C" fn connect_rdp(go_ref: usize, params: CGOConnectParams) ->
ConnectParams {
addr,
username,
client_id,
cert_der,
key_der,
screen_width: params.screen_width,
Expand Down Expand Up @@ -230,6 +235,7 @@ const RDPSND_CHANNEL_NAME: &str = "rdpsnd";
pub struct CGOConnectParams {
go_addr: *const c_char,
go_username: *const c_char,
go_client_id: *const c_char,
cert_der_len: u32,
cert_der: *mut u8,
key_der_len: u32,
Expand All @@ -244,6 +250,7 @@ pub struct CGOConnectParams {
struct ConnectParams {
addr: String,
username: String,
client_id: String,
cert_der: Vec<u8>,
key_der: Vec<u8>,
screen_width: u16,
Expand All @@ -253,6 +260,8 @@ struct ConnectParams {
show_desktop_wallpaper: bool,
}

static LICENSE_STORE: OnceLock<SyncLicenseStore<MemoryLicenseStore>> = OnceLock::new();

fn connect_rdp_inner(go_ref: usize, params: ConnectParams) -> Result<Client, ConnectError> {
// Connect and authenticate.
let addr = params
Expand Down Expand Up @@ -288,7 +297,7 @@ fn connect_rdp_inner(go_ref: usize, params: ConnectParams) -> Result<Client, Con
static_channels.push(cliprdr::CHANNEL_NAME.to_string())
}
mcs.connect(
"rdp-rs".to_string(),
params.client_id.clone(),
params.screen_width,
params.screen_height,
KeyboardLayout::US,
Expand All @@ -302,8 +311,13 @@ fn connect_rdp_inner(go_ref: usize, params: ConnectParams) -> Result<Client, Con
if !params.show_desktop_wallpaper {
performance_flags |= sec::ExtendedInfoFlag::PerfDisableWallpaper as u32;
}

let license_store =
LICENSE_STORE.get_or_init(|| SyncLicenseStore::new(MemoryLicenseStore::new()));

sec::connect(
&mut mcs,
&params.client_id,
&domain.to_string(),
&params.username,
&pin,
Expand All @@ -312,6 +326,7 @@ fn connect_rdp_inner(go_ref: usize, params: ConnectParams) -> Result<Client, Con
// which is known only to Teleport and unique for each RDP session.
Some(sec::InfoFlag::InfoPasswordIsScPin as u32 | sec::InfoFlag::InfoMouseHasWheel as u32),
Some(performance_flags),
license_store,
)?;
// Client for the "global" channel - video output and user input.
let global = global::Client::new(
Expand Down Expand Up @@ -2048,6 +2063,56 @@ pub struct CGOSharedDirectoryListRequest {
pub path: *const c_char,
}

/// SyncLicenseStore protects an underlying LicenseStore with a mutex,
/// making it safe for concurrent use.
#[derive(Default)]
struct SyncLicenseStore<L> {
inner: Mutex<L>,
}

impl<L> SyncLicenseStore<L> {
pub fn new(license_store: L) -> Self {
Self {
inner: Mutex::new(license_store),
}
}

pub fn into_inner(self) -> L {
self.inner.into_inner().unwrap()
}
}

impl<L: LicenseStore> LicenseStore for &SyncLicenseStore<L> {
fn write_license(
&mut self,
major: u16,
minor: u16,
company: &str,
issuer: &str,
product_id: &str,
license: &[u8],
) {
self.inner
.lock()
.unwrap()
.write_license(major, minor, company, issuer, product_id, license);
}

fn read_license(
&self,
major: u16,
minor: u16,
company: &str,
issuer: &str,
product_id: &str,
) -> Option<Vec<u8>> {
self.inner
.lock()
.unwrap()
.read_license(major, minor, company, issuer, product_id)
}
}

// These functions are defined on the Go side. Look for functions with '//export funcname'
// comments.
extern "C" {
Expand Down
1 change: 1 addition & 0 deletions lib/srv/desktop/windows_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,7 @@ func (s *WindowsService) connectRDP(ctx context.Context, log logrus.FieldLogger,
},
CertTTL: windows.CertTTL,
Addr: addr.String(),
HostID: s.cfg.Heartbeat.HostUUID,
Conn: tdpConn,
AuthorizeFn: authorize,
AllowClipboard: authCtx.Checker.DesktopClipboard(),
Expand Down

0 comments on commit 667af23

Please sign in to comment.