Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for backup transfer #6027

Merged
merged 5 commits into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ once_cell = { workspace = true }
parking_lot = "0.12"
percent-encoding = "2.3"
pgp = { version = "0.13.2", default-features = false }
pin-project = "1"
qrcodegen = "1.7.0"
quick-xml = "0.36"
quoted_printable = "0.5"
Expand Down
4 changes: 0 additions & 4 deletions src/blob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -666,10 +666,6 @@ impl<'a> BlobDirContents<'a> {
pub(crate) fn iter(&self) -> BlobDirIter<'_> {
BlobDirIter::new(self.context, self.inner.iter())
}

pub(crate) fn len(&self) -> usize {
self.inner.len()
}
}

/// An iterator over all the [`BlobObject`]s in the blobdir.
Expand Down
174 changes: 146 additions & 28 deletions src/imex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

use std::ffi::OsStr;
use std::path::{Path, PathBuf};
use std::pin::Pin;

use ::pgp::types::KeyTrait;
use anyhow::{bail, ensure, format_err, Context as _, Result};
use futures::TryStreamExt;
use futures_lite::FutureExt;
use pin_project::pin_project;

use tokio::fs::{self, File};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio_tar::Archive;

use crate::blob::BlobDirContents;
Expand Down Expand Up @@ -212,7 +215,7 @@ async fn imex_inner(
path.display()
);
ensure!(context.sql.is_open().await, "Database not opened.");
context.emit_event(EventType::ImexProgress(10));
context.emit_event(EventType::ImexProgress(1));
link2xt marked this conversation as resolved.
Show resolved Hide resolved

if what == ImexMode::ExportBackup || what == ImexMode::ExportSelfKeys {
// before we export anything, make sure the private key exists
Expand Down Expand Up @@ -294,42 +297,84 @@ pub(crate) async fn import_backup_stream<R: tokio::io::AsyncRead + Unpin>(
.0
}

/// Reader that emits progress events as bytes are read from it.
#[pin_project]
struct ProgressReader<R> {
/// Wrapped reader.
#[pin]
inner: R,

/// Number of bytes successfully read from the internal reader.
read: usize,
link2xt marked this conversation as resolved.
Show resolved Hide resolved

/// Total size of the backup .tar file expected to be read from the reader.
/// Used to calculate the progress.
file_size: usize,

/// Last progress emitted to avoid emitting the same progress value twice.
last_progress: usize,
link2xt marked this conversation as resolved.
Show resolved Hide resolved

/// Context for emitting progress events.
context: Context,
}

impl<R> ProgressReader<R> {
fn new(r: R, context: Context, file_size: u64) -> Self {
Self {
inner: r,
read: 0,
file_size: file_size as usize,
last_progress: 1,
context,
}
}
}

impl<R> AsyncRead for ProgressReader<R>
where
    R: AsyncRead,
{
    /// Polls the wrapped reader, accounts the bytes read and emits an
    /// `ImexProgress` event whenever the permille progress value increases.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let this = self.project();
        let before = buf.filled().len();
        let res = this.inner.poll_read(cx, buf);
        if let std::task::Poll::Ready(Ok(())) = res {
            // `filled()` only grows during a successful poll,
            // so the difference is the number of bytes just read.
            *this.read = this.read.saturating_add(buf.filled().len() - before);

            // Cap at 999: 1000 is reserved for the final success event.
            // Guard the denominator so a zero `file_size` cannot panic
            // with a division by zero.
            let progress = std::cmp::min(
                1000 * *this.read / std::cmp::max(*this.file_size, 1),
                999,
            );
            if progress > *this.last_progress {
                this.context.emit_event(EventType::ImexProgress(progress));
                *this.last_progress = progress;
            }
        }
        res
    }
}

async fn import_backup_stream_inner<R: tokio::io::AsyncRead + Unpin>(
context: &Context,
backup_file: R,
file_size: u64,
passphrase: String,
) -> (Result<()>,) {
let backup_file = ProgressReader::new(backup_file, context.clone(), file_size);
let mut archive = Archive::new(backup_file);

let mut entries = match archive.entries() {
Ok(entries) => entries,
Err(e) => return (Err(e).context("Failed to get archive entries"),),
};
let mut blobs = Vec::new();
// We already emitted ImexProgress(10) above
let mut last_progress = 10;
const PROGRESS_MIGRATIONS: u128 = 999;
let mut total_size: u64 = 0;
let mut res: Result<()> = loop {
let mut f = match entries.try_next().await {
Ok(Some(f)) => f,
Ok(None) => break Ok(()),
Err(e) => break Err(e).context("Failed to get next entry"),
};
total_size += match f.header().entry_size() {
Ok(size) => size,
Err(e) => break Err(e).context("Failed to get entry size"),
};
let max = PROGRESS_MIGRATIONS - 1;
let progress = std::cmp::min(
max * u128::from(total_size) / std::cmp::max(u128::from(file_size), 1),
max,
);
if progress > last_progress {
context.emit_event(EventType::ImexProgress(progress as usize));
last_progress = progress;
}

let path = match f.path() {
Ok(path) => path.to_path_buf(),
Expand Down Expand Up @@ -379,7 +424,7 @@ async fn import_backup_stream_inner<R: tokio::io::AsyncRead + Unpin>(
.log_err(context)
.ok();
if res.is_ok() {
context.emit_event(EventType::ImexProgress(PROGRESS_MIGRATIONS as usize));
context.emit_event(EventType::ImexProgress(999));
res = context.sql.run_migrations(context).await;
}
if res.is_ok() {
Expand Down Expand Up @@ -452,41 +497,114 @@ async fn export_backup(context: &Context, dir: &Path, passphrase: String) -> Res

let file = File::create(&temp_path).await?;
let blobdir = BlobDirContents::new(context).await?;
export_backup_stream(context, &temp_db_path, blobdir, file)

let mut file_size = 0;
file_size += temp_db_path.metadata()?.len();
for blob in blobdir.iter() {
file_size += blob.to_abs_path().metadata()?.len()
}

export_backup_stream(context, &temp_db_path, blobdir, file, file_size)
.await
.context("Exporting backup to file failed")?;
fs::rename(temp_path, &dest_path).await?;
context.emit_event(EventType::ImexFileWritten(dest_path));
Ok(())
}

/// Writer wrapper that reports backup-export progress.
///
/// Emits `ImexProgress` events on the wrapped [`Context`] as data is
/// written through it.
#[pin_project]
struct ProgressWriter<W> {
    /// The underlying writer all data is forwarded to.
    #[pin]
    inner: W,

    /// Count of bytes the inner writer has accepted so far.
    written: usize,

    /// Expected total size of the backup .tar stream in bytes,
    /// used as the denominator when computing progress.
    file_size: usize,

    /// Most recently emitted progress value, kept so duplicate
    /// events are suppressed.
    last_progress: usize,

    /// Context used to emit the progress events.
    context: Context,
}

impl<W> ProgressWriter<W> {
    /// Wraps `w` so that every successful write updates the progress.
    ///
    /// `file_size` is the expected total number of bytes to be written.
    fn new(w: W, context: Context, file_size: u64) -> Self {
        let file_size = file_size as usize;
        Self {
            inner: w,
            written: 0,
            file_size,
            // Progress 1 was already emitted before the transfer began,
            // so never re-emit values at or below it.
            last_progress: 1,
            context,
        }
    }
}

impl<W> AsyncWrite for ProgressWriter<W>
where
    W: AsyncWrite,
{
    /// Forwards the write to the wrapped writer, accounts the bytes
    /// accepted and emits an `ImexProgress` event whenever the permille
    /// progress value increases.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        let this = self.project();
        let res = this.inner.poll_write(cx, buf);
        if let std::task::Poll::Ready(Ok(written)) = res {
            *this.written = this.written.saturating_add(written);

            // Cap at 999: 1000 is reserved for the final success event.
            // Guard the denominator so a zero `file_size` cannot panic
            // with a division by zero.
            let progress = std::cmp::min(
                1000 * *this.written / std::cmp::max(*this.file_size, 1),
                999,
            );
            if progress > *this.last_progress {
                this.context.emit_event(EventType::ImexProgress(progress));
                *this.last_progress = progress;
            }
        }
        res
    }

    /// Flushing is delegated unchanged to the wrapped writer.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        self.project().inner.poll_flush(cx)
    }

    /// Shutdown is delegated unchanged to the wrapped writer.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        self.project().inner.poll_shutdown(cx)
    }
}

/// Exports the database and blobs into a stream.
pub(crate) async fn export_backup_stream<'a, W>(
context: &'a Context,
temp_db_path: &Path,
blobdir: BlobDirContents<'a>,
writer: W,
file_size: u64,
) -> Result<()>
where
W: tokio::io::AsyncWrite + tokio::io::AsyncWriteExt + Unpin + Send + 'static,
{
let writer = ProgressWriter::new(writer, context.clone(), file_size);
let mut builder = tokio_tar::Builder::new(writer);

builder
.append_path_with_name(temp_db_path, DBFILE_BACKUP_NAME)
.await?;

let mut last_progress = 10;

for (i, blob) in blobdir.iter().enumerate() {
for blob in blobdir.iter() {
let mut file = File::open(blob.to_abs_path()).await?;
let path_in_archive = PathBuf::from(BLOBS_BACKUP_NAME).join(blob.as_name());
builder.append_file(path_in_archive, &mut file).await?;
let progress = std::cmp::min(1000 * i / blobdir.len(), 999);
if progress > last_progress {
context.emit_event(EventType::ImexProgress(progress));
last_progress = progress;
}
}

builder.finish().await?;
Expand Down
39 changes: 34 additions & 5 deletions src/imex/transfer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::Poll;

use anyhow::{bail, Context as _, Result};
use anyhow::{bail, format_err, Context as _, Result};
use futures_lite::FutureExt;
use iroh_net::relay::RelayMode;
use iroh_net::Endpoint;
use tokio::fs;
Expand Down Expand Up @@ -108,6 +109,7 @@ impl BackupProvider {
.get_blobdir()
.parent()
.context("Context dir not found")?;

let dbfile = context_dir.join(DBFILE_BACKUP_NAME);
if fs::metadata(&dbfile).await.is_ok() {
fs::remove_file(&dbfile).await?;
Expand All @@ -123,7 +125,6 @@ impl BackupProvider {
export_database(context, &dbfile, passphrase, time())
.await
.context("Database export failed")?;
context.emit_event(EventType::ImexProgress(300));

let drop_token = CancellationToken::new();
let handle = {
Expand Down Expand Up @@ -177,6 +178,7 @@ impl BackupProvider {
}

info!(context, "Received valid backup authentication token.");
context.emit_event(EventType::ImexProgress(1));

let blobdir = BlobDirContents::new(&context).await?;

Expand All @@ -188,7 +190,7 @@ impl BackupProvider {

send_stream.write_all(&file_size.to_be_bytes()).await?;

export_backup_stream(&context, &dbfile, blobdir, send_stream)
export_backup_stream(&context, &dbfile, blobdir, send_stream, file_size)
.await
.context("Failed to write backup into QUIC stream")?;
info!(context, "Finished writing backup into QUIC stream.");
Expand Down Expand Up @@ -231,8 +233,20 @@ impl BackupProvider {
let context = context.clone();
let auth_token = auth_token.clone();
let dbfile = dbfile.clone();
if let Err(err) = Self::handle_connection(context.clone(), conn, auth_token, dbfile).await {
if let Err(err) = Self::handle_connection(context.clone(), conn, auth_token, dbfile).race(
async {
cancel_token.recv().await.ok();
Err(format_err!("Backup transfer cancelled"))
}
).race(
async {
drop_token.cancelled().await;
Err(format_err!("Backup provider dropped"))
}
).await {
warn!(context, "Error while handling backup connection: {err:#}.");
Copy link
Collaborator

@Hocuri Hocuri Oct 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, if the transfer is cancelled during endpoint.accept(..), we do context.emit_event(EventType::ImexProgress(0)), and write nothing to the log.

If the transfer is cancelled during handle_connection(..), we write a message to the log, don't emit an event, and continue with the next iteration of the loop, (i.e. try to call endpoint.accept(..)) again. (this was fixed)

This seems surprising?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed by emitting progress 0. Cancellation during transfer is still logged as a warning unlike cancellation of accept loop, but it is just logging and not shown in the UI anyway.

context.emit_event(EventType::ImexProgress(0));
break;
} else {
info!(context, "Backup transfer finished successfully.");
break;
Expand All @@ -242,10 +256,12 @@ impl BackupProvider {
}
},
_ = cancel_token.recv() => {
info!(context, "Backup transfer cancelled by the user, stopping accept loop.");
context.emit_event(EventType::ImexProgress(0));
break;
}
_ = drop_token.cancelled() => {
info!(context, "Backup transfer cancelled by dropping the provider, stopping accept loop.");
context.emit_event(EventType::ImexProgress(0));
break;
}
Expand Down Expand Up @@ -329,7 +345,20 @@ pub async fn get_backup(context: &Context, qr: Qr) -> Result<()> {
Qr::Backup2 {
node_addr,
auth_token,
} => get_backup2(context, node_addr, auth_token).await?,
} => {
let cancel_token = context.alloc_ongoing().await?;
let res = get_backup2(context, node_addr, auth_token)
.race(async {
cancel_token.recv().await.ok();
Err(format_err!("Backup reception cancelled"))
})
.await;
if res.is_err() {
context.emit_event(EventType::ImexProgress(0));
}
context.free_ongoing().await;
res?;
}
_ => bail!("QR code for backup must be of type DCBACKUP2"),
}
Ok(())
Expand Down
Loading