Skip to content

Commit

Permalink
Fix model functions using SerializeIter (#2955)
Browse files Browse the repository at this point in the history
When passing `SerializeIter` in from outside, it counts as passing `!Send`
over an await point and therefore making these functions unusable.

This fixes that by moving the `SerializeIter` usage inside of `Http`, or
replacing with a slice when necessary.
  • Loading branch information
GnomedDev authored Aug 26, 2024
1 parent 0c27128 commit 849daf3
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 37 deletions.
35 changes: 31 additions & 4 deletions src/http/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![allow(clippy::missing_errors_doc)]

use std::borrow::Cow;
use std::cell::Cell;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

Expand All @@ -13,6 +14,7 @@ use reqwest::Url;
use reqwest::{Client, ClientBuilder, Response as ReqwestResponse, StatusCode};
use secrecy::{ExposeSecret as _, Secret};
use serde::de::DeserializeOwned;
use serde::ser::SerializeSeq as _;
use serde_json::{from_value, json, to_string, to_vec};
use tracing::{debug, trace};

Expand Down Expand Up @@ -65,6 +67,29 @@ impl secrecy::Zeroize for Token {
impl secrecy::CloneableSecret for Token {}
impl secrecy::DebugSecret for Token {}

// NOTE: This cannot be passed in from outside, due to `Cell` being !Send.
struct SerializeIter<I>(Cell<Option<I>>);

impl<I> SerializeIter<I> {
pub fn new(iter: I) -> Self {
Self(Cell::new(Some(iter)))
}
}

impl<Iter, Item> serde::Serialize for SerializeIter<Iter>
where
Iter: Iterator<Item = Item>,
Item: serde::Serialize,
{
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let Some(iter) = self.0.take() else {
return serializer.serialize_seq(Some(0))?.end();
};

serializer.collect_seq(iter)
}
}

/// A builder for the underlying [`Http`] client that performs requests to Discord's HTTP API. If
/// you do not need to use a proxy or do not need to disable the rate limiter, you can use
/// [`Http::new`] instead.
Expand Down Expand Up @@ -1671,9 +1696,9 @@ impl Http {
pub async fn edit_guild_channel_positions(
&self,
guild_id: GuildId,
value: &impl serde::Serialize,
value: impl Iterator<Item: serde::Serialize>,
) -> Result<()> {
let body = to_vec(value)?;
let body = to_vec(&SerializeIter::new(value))?;

self.wind(204, Request {
body: Some(body),
Expand Down Expand Up @@ -2010,12 +2035,14 @@ impl Http {
pub async fn edit_role_positions(
&self,
guild_id: GuildId,
map: &impl serde::Serialize,
positions: impl Iterator<Item: serde::Serialize>,
audit_log_reason: Option<&str>,
) -> Result<Vec<Role>> {
let body = to_vec(&SerializeIter::new(positions))?;

let mut value: Value = self
.fire(Request {
body: Some(to_vec(&map)?),
body: Some(body),
multipart: None,
headers: audit_log_reason.map(reason_into_header),
method: LightMethod::Patch,
Expand Down
13 changes: 6 additions & 7 deletions src/model/guild/guild_id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use crate::http::{CacheHttp, Http, UserPagination};
#[cfg(feature = "model")]
use crate::internal::prelude::*;
use crate::model::error::Maximum;
use crate::model::guild::SerializeIter;
use crate::model::prelude::*;

#[cfg(feature = "model")]
Expand Down Expand Up @@ -219,18 +218,18 @@ impl GuildId {
pub async fn bulk_ban(
self,
http: &Http,
users: impl IntoIterator<Item = UserId>,
user_ids: &[UserId],
delete_message_seconds: u32,
reason: Option<&str>,
) -> Result<BulkBanResponse> {
#[derive(serde::Serialize)]
struct BulkBan<I> {
user_ids: I,
struct BulkBan<'a> {
user_ids: &'a [UserId],
delete_message_seconds: u32,
}

let map = BulkBan {
user_ids: SerializeIter::new(users.into_iter()),
user_ids,
delete_message_seconds,
};

Expand Down Expand Up @@ -844,7 +843,7 @@ impl GuildId {
position,
});

http.edit_role_positions(self, &SerializeIter::new(iter), reason).await
http.edit_role_positions(self, iter, reason).await
}

/// Edits the guild's welcome screen.
Expand Down Expand Up @@ -1198,7 +1197,7 @@ impl GuildId {
position,
});

http.edit_guild_channel_positions(self, &SerializeIter::new(iter)).await
http.edit_guild_channel_positions(self, iter).await
}

/// Returns a list of [`Member`]s in a [`Guild`] whose username or nickname starts with a
Expand Down
4 changes: 2 additions & 2 deletions src/model/guild/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,11 +430,11 @@ impl Guild {
pub async fn bulk_ban(
&self,
http: &Http,
users: impl IntoIterator<Item = UserId>,
user_ids: &[UserId],
delete_message_seconds: u32,
reason: Option<&str>,
) -> Result<BulkBanResponse> {
self.id.bulk_ban(http, users, delete_message_seconds, reason).await
self.id.bulk_ban(http, user_ids, delete_message_seconds, reason).await
}

/// Returns the formatted URL of the guild's banner image, if one exists.
Expand Down
24 changes: 0 additions & 24 deletions src/model/utils.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use std::cell::Cell;
use std::fmt;

use arrayvec::ArrayVec;
use serde::de::Error as DeError;
use serde::ser::SerializeSeq;
use serde_cow::CowStr;
use small_fixed_array::FixedString;

Expand Down Expand Up @@ -70,28 +68,6 @@ where
remove_from_map_opt(map, key)?.ok_or_else(|| serde::de::Error::missing_field(key))
}

pub(super) struct SerializeIter<I>(Cell<Option<I>>);

impl<I> SerializeIter<I> {
pub fn new(iter: I) -> Self {
Self(Cell::new(Some(iter)))
}
}

impl<Iter, Item> serde::Serialize for SerializeIter<Iter>
where
Iter: Iterator<Item = Item>,
Item: serde::Serialize,
{
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let Some(iter) = self.0.take() else {
return serializer.serialize_seq(Some(0))?.end();
};

serializer.collect_seq(iter)
}
}

pub(super) enum StrOrInt<'de> {
String(String),
Str(&'de str),
Expand Down

0 comments on commit 849daf3

Please sign in to comment.