diff --git a/crates/rattler_installs_packages/src/python_env/byte_code_compiler.rs b/crates/rattler_installs_packages/src/python_env/byte_code_compiler.rs index 1c72920e..65198b96 100644 --- a/crates/rattler_installs_packages/src/python_env/byte_code_compiler.rs +++ b/crates/rattler_installs_packages/src/python_env/byte_code_compiler.rs @@ -15,6 +15,7 @@ type CompilationResponse = Result; type CompilationRequest = PathBuf; type BoxedCallback = Box; +type CompilationCallbackMap = HashMap>; /// An error that can occur when compiling a source file. #[derive(Debug, Error, Clone)] @@ -57,7 +58,7 @@ pub struct ByteCodeCompiler { /// Callback functions per compilation request. These are called when the compilation host /// finishes processing a request. - pending_callbacks: Arc>>>, + pending_callbacks: Arc>>, /// The child process. This is waited upon on drop. child: Option, @@ -114,10 +115,7 @@ impl ByteCodeCompiler { // Spawn another thread to process the output of the compilation process and forward it to the // response channel. - let pending_callbacks = Arc::new(Mutex::new(HashMap::< - CompilationRequest, - Vec, - >::new())); + let pending_callbacks = Arc::new(Mutex::new(Some(CompilationCallbackMap::new()))); let response_callbacks = pending_callbacks.clone(); let child_stdout = BufReader::new(child.stdout.take().expect("stdout is piped")); std::thread::spawn(move || { @@ -132,8 +130,14 @@ impl ByteCodeCompiler { Ok(response) => { tracing::trace!("finished compiling '{}'", response.path.display()); - let mut callbacks = response_callbacks.lock(); - match callbacks.remove(&response.path) { + let callbacks = { + let mut callback_lock = response_callbacks.lock(); + let callbacks = callback_lock.as_mut().expect( + "the callbacks are not dropped until the end of this function", + ); + callbacks.remove(&response.path) + }; + match callbacks { None => panic!( "received a response for an unknown request '{}'", response.path.display() @@ -157,6 +161,19 @@ impl ByteCodeCompiler { } } } + + tracing::trace!("compilation host stdout closed"); + + // Abort any pending callbacks and disable the ability to add new ones. + let callbacks = response_callbacks + .lock() + .take() + .expect("only we can drop the callbacks"); + for (_, callbacks) in callbacks { + for callback in callbacks { + callback(Err(CompilationError::HostQuit)) + } + } }); Ok(Self { @@ -186,8 +203,12 @@ impl ByteCodeCompiler { return Err(CompilationError::SourceNotFound); } - self.pending_callbacks - .lock() + let mut lock = self.pending_callbacks.lock(); + let Some(callbacks) = lock.as_mut() else { + return Err(CompilationError::HostQuit); + }; + + callbacks .entry(source_path.to_path_buf()) .or_default() .push(Box::new(callback));