diff --git a/crates/stdlib/src/ssl.rs b/crates/stdlib/src/ssl.rs index 31973a84f89..6c62a24397b 100644 --- a/crates/stdlib/src/ssl.rs +++ b/crates/stdlib/src/ssl.rs @@ -50,7 +50,7 @@ mod _ssl { // Import error types used in this module (others are exposed via pymodule(with(...))) use super::error::{ - PySSLEOFError, PySSLError, create_ssl_want_read_error, create_ssl_want_write_error, + PySSLError, create_ssl_eof_error, create_ssl_want_read_error, create_ssl_want_write_error, }; use alloc::sync::Arc; use core::{ @@ -1903,6 +1903,7 @@ mod _ssl { client_hello_buffer: PyMutex::new(None), shutdown_state: PyMutex::new(ShutdownState::NotStarted), pending_tls_output: PyMutex::new(Vec::new()), + write_buffered_len: PyMutex::new(0), deferred_cert_error: Arc::new(ParkingRwLock::new(None)), }; @@ -1974,6 +1975,7 @@ mod _ssl { client_hello_buffer: PyMutex::new(None), shutdown_state: PyMutex::new(ShutdownState::NotStarted), pending_tls_output: PyMutex::new(Vec::new()), + write_buffered_len: PyMutex::new(0), deferred_cert_error: Arc::new(ParkingRwLock::new(None)), }; @@ -2345,6 +2347,10 @@ mod _ssl { // but the socket cannot accept all the data immediately #[pytraverse(skip)] pub(crate) pending_tls_output: PyMutex>, + // Tracks bytes already buffered in rustls for the current write operation + // Prevents duplicate writes when retrying after WantWrite/WantRead + #[pytraverse(skip)] + pub(crate) write_buffered_len: PyMutex, // Deferred client certificate verification error (for TLS 1.3) // Stores error message if client cert verification failed during handshake // Error is raised on first I/O operation after handshake @@ -2604,6 +2610,36 @@ mod _ssl { Ok(timed_out) } + // Internal implementation with explicit timeout override + pub(crate) fn sock_wait_for_io_with_timeout( + &self, + kind: SelectKind, + timeout: Option, + vm: &VirtualMachine, + ) -> PyResult { + if self.is_bio_mode() { + // BIO mode doesn't use select + return Ok(false); + } + + if let Some(t) = timeout + && t.is_zero() + { + // Non-blocking mode - don't use select + return Ok(false); + } + + let py_socket: PyRef = self.sock.clone().try_into_value(vm)?; + let socket = py_socket + .sock() + .map_err(|e| vm.new_os_error(format!("Failed to get socket: {e}")))?; + + let timed_out = sock_select(&socket, kind, timeout) + .map_err(|e| vm.new_os_error(format!("select failed: {e}")))?; + + Ok(timed_out) + } + // SNI (Server Name Indication) Helper Methods: // These methods support the server-side handshake SNI callback mechanism @@ -2783,6 +2819,7 @@ mod _ssl { let is_non_blocking = socket_timeout.map(|t| t.is_zero()).unwrap_or(false); let mut sent_total = 0; + while sent_total < pending.len() { // Calculate timeout: use deadline if provided, otherwise use socket timeout let timeout_to_use = if let Some(dl) = deadline { @@ -2810,6 +2847,9 @@ mod _ssl { if timed_out { // Keep unsent data in pending buffer *pending = pending[sent_total..].to_vec(); + if is_non_blocking { + return Err(create_ssl_want_write_error(vm).upcast()); + } return Err( timeout_error_msg(vm, "The write operation timed out".to_string()).upcast(), ); @@ -2824,6 +2864,7 @@ mod _ssl { *pending = pending[sent_total..].to_vec(); return Err(create_ssl_want_write_error(vm).upcast()); } + // Socket said ready but sent 0 bytes - retry continue; } sent_total += sent; @@ -2916,6 +2957,9 @@ mod _ssl { pub(crate) fn blocking_flush_all_pending(&self, vm: &VirtualMachine) -> PyResult<()> { // Get socket timeout to respect during flush let timeout = self.get_socket_timeout(vm)?; + if timeout.map(|t| t.is_zero()).unwrap_or(false) { + return self.flush_pending_tls_output(vm, None); + } loop { let pending_data = { @@ -2948,8 +2992,7 @@ mod _ssl { let mut pending = self.pending_tls_output.lock(); pending.drain(..sent); } - // If sent == 0, socket wasn't ready despite select() saying so - // Continue loop to retry - this avoids infinite loops + // If sent == 0, loop will retry with sock_select } Err(e) => { if is_blocking_io_error(&e, vm) { @@ -3515,16 +3558,60 @@ mod _ssl { return_data(buf, &buffer, vm) } Err(crate::ssl::compat::SslError::Eof) => { + // If plaintext is still buffered, return it before EOF. + let pending = { + let mut conn_guard = self.connection.lock(); + let conn = match conn_guard.as_mut() { + Some(conn) => conn, + None => return Err(create_ssl_eof_error(vm).upcast()), + }; + use std::io::BufRead; + let mut reader = conn.reader(); + reader.fill_buf().map(|buf| buf.len()).unwrap_or(0) + }; + if pending > 0 { + let mut buf = vec![0u8; pending.min(len)]; + let read_retry = { + let mut conn_guard = self.connection.lock(); + let conn = conn_guard + .as_mut() + .ok_or_else(|| vm.new_value_error("Connection not established"))?; + crate::ssl::compat::ssl_read(conn, &mut buf, self, vm) + }; + if let Ok(n) = read_retry { + buf.truncate(n); + return return_data(buf, &buffer, vm); + } + } // EOF occurred in violation of protocol (unexpected closure) - Err(vm - .new_os_subtype_error( - PySSLEOFError::class(&vm.ctx).to_owned(), - None, - "EOF occurred in violation of protocol", - ) - .upcast()) + Err(create_ssl_eof_error(vm).upcast()) } Err(crate::ssl::compat::SslError::ZeroReturn) => { + // If plaintext is still buffered, return it before clean EOF. + let pending = { + let mut conn_guard = self.connection.lock(); + let conn = match conn_guard.as_mut() { + Some(conn) => conn, + None => return return_data(vec![], &buffer, vm), + }; + use std::io::BufRead; + let mut reader = conn.reader(); + reader.fill_buf().map(|buf| buf.len()).unwrap_or(0) + }; + if pending > 0 { + let mut buf = vec![0u8; pending.min(len)]; + let read_retry = { + let mut conn_guard = self.connection.lock(); + let conn = conn_guard + .as_mut() + .ok_or_else(|| vm.new_value_error("Connection not established"))?; + crate::ssl::compat::ssl_read(conn, &mut buf, self, vm) + }; + if let Ok(n) = read_retry { + buf.truncate(n); + return return_data(buf, &buffer, vm); + } + } // Clean closure with close_notify - return empty data return_data(vec![], &buffer, vm) } @@ -3580,21 +3667,17 @@ mod _ssl { let data_bytes = data.borrow_buf(); let data_len = data_bytes.len(); - // return 0 immediately for empty write if data_len == 0 { return Ok(0); } - // Ensure handshake is done - if not, complete it first - // This matches OpenSSL behavior where SSL_write() auto-completes handshake + // Ensure handshake is done (SSL_write auto-completes handshake) if !*self.handshake_done.lock() { self.do_handshake(vm)?; } - // Check if connection has been shut down - // After unwrap()/shutdown(), write operations should fail with SSLError - let shutdown_state = *self.shutdown_state.lock(); - if shutdown_state != ShutdownState::NotStarted { + // Check shutdown state + if *self.shutdown_state.lock() != ShutdownState::NotStarted { return Err(vm .new_os_subtype_error( PySSLError::class(&vm.ctx).to_owned(), @@ -3604,76 +3687,32 @@ mod _ssl { .upcast()); } - { + // Call ssl_write (matches CPython's SSL_write_ex loop) + let result = { let mut conn_guard = self.connection.lock(); let conn = conn_guard .as_mut() .ok_or_else(|| vm.new_value_error("Connection not established"))?; - let is_bio = self.is_bio_mode(); - let data: &[u8] = data_bytes.as_ref(); + crate::ssl::compat::ssl_write(conn, data_bytes.as_ref(), self, vm) + }; - // CRITICAL: Flush any pending TLS data before writing new data - // This ensures TLS 1.3 Finished message reaches server before application data - // Without this, server may not be ready to process our data - if !is_bio { - self.flush_pending_tls_output(vm, None)?; + match result { + Ok(n) => { + self.check_deferred_cert_error(vm)?; + Ok(n) } - - // Write data in chunks to avoid filling the internal TLS buffer - // rustls has a limited internal buffer, so we need to flush periodically - const CHUNK_SIZE: usize = 16384; // 16KB chunks (typical TLS record size) - let mut written = 0; - - while written < data.len() { - let chunk_end = core::cmp::min(written + CHUNK_SIZE, data.len()); - let chunk = &data[written..chunk_end]; - - // Write chunk to TLS layer - { - let mut writer = conn.writer(); - use std::io::Write; - writer - .write_all(chunk) - .map_err(|e| vm.new_os_error(format!("Write failed: {e}")))?; - // Flush to ensure data is converted to TLS records - writer - .flush() - .map_err(|e| vm.new_os_error(format!("Flush failed: {e}")))?; - } - - written = chunk_end; - - // Flush TLS data to socket after each chunk - if conn.wants_write() { - if is_bio { - self.write_pending_tls(conn, vm)?; - } else { - // Socket mode: flush all pending TLS data - // First, try to send any previously pending data - self.flush_pending_tls_output(vm, None)?; - - while conn.wants_write() { - let mut buf = Vec::new(); - conn.write_tls(&mut buf).map_err(|e| { - vm.new_os_error(format!("TLS write failed: {e}")) - })?; - - if !buf.is_empty() { - // Try to send TLS data, saving unsent bytes to pending buffer - self.send_tls_output(buf, vm)?; - } - } - } - } + Err(crate::ssl::compat::SslError::WantRead) => { + Err(create_ssl_want_read_error(vm).upcast()) + } + Err(crate::ssl::compat::SslError::WantWrite) => { + Err(create_ssl_want_write_error(vm).upcast()) + } + Err(crate::ssl::compat::SslError::Timeout(msg)) => { + Err(timeout_error_msg(vm, msg).upcast()) } + Err(e) => Err(e.into_py_err(vm)), } - - // Check for deferred certificate verification errors (TLS 1.3) - // Must be checked AFTER write completes, as the error may be set during I/O - self.check_deferred_cert_error(vm)?; - - Ok(data_len) } #[pymethod] @@ -4013,6 +4052,10 @@ mod _ssl { // Write close_notify to outgoing buffer/BIO self.write_pending_tls(conn, vm)?; + // Ensure close_notify and any pending TLS data are flushed + if !is_bio { + self.flush_pending_tls_output(vm, None)?; + } // Update state *self.shutdown_state.lock() = ShutdownState::SentCloseNotify; diff --git a/crates/stdlib/src/ssl/compat.rs b/crates/stdlib/src/ssl/compat.rs index 3c72ccf4e21..322fdde5b9a 100644 --- a/crates/stdlib/src/ssl/compat.rs +++ b/crates/stdlib/src/ssl/compat.rs @@ -36,8 +36,8 @@ use super::_ssl::PySSLSocket; // Import error types and helper functions from error module use super::error::{ - PySSLCertVerificationError, PySSLError, create_ssl_eof_error, create_ssl_want_read_error, - create_ssl_want_write_error, create_ssl_zero_return_error, + PySSLCertVerificationError, PySSLError, create_ssl_eof_error, create_ssl_syscall_error, + create_ssl_want_read_error, create_ssl_want_write_error, create_ssl_zero_return_error, }; // SSL Verification Flags @@ -553,8 +553,8 @@ impl SslError { SslError::WantWrite => create_ssl_want_write_error(vm).upcast(), SslError::Timeout(msg) => timeout_error_msg(vm, msg).upcast(), SslError::Syscall(msg) => { - // Create SSLError with library=None for syscall errors during SSL operations - Self::create_ssl_error_with_reason(vm, None, &msg, msg.clone()) + // SSLSyscallError with errno=SSL_ERROR_SYSCALL (5) + create_ssl_syscall_error(vm, msg).upcast() } SslError::Ssl(msg) => vm .new_os_subtype_error( @@ -1039,6 +1039,36 @@ fn send_all_bytes( return Err(SslError::Timeout("The operation timed out".to_string())); } + // Wait for socket to be writable before sending + let timed_out = if let Some(dl) = deadline { + let now = std::time::Instant::now(); + if now >= dl { + socket + .pending_tls_output + .lock() + .extend_from_slice(&buf[sent_total..]); + return Err(SslError::Timeout( + "The write operation timed out".to_string(), + )); + } + socket + .sock_wait_for_io_with_timeout(SelectKind::Write, Some(dl - now), vm) + .map_err(SslError::Py)? + } else { + socket + .sock_wait_for_io_impl(SelectKind::Write, vm) + .map_err(SslError::Py)? + }; + if timed_out { + socket + .pending_tls_output + .lock() + .extend_from_slice(&buf[sent_total..]); + return Err(SslError::Timeout( + "The write operation timed out".to_string(), + )); + } + match socket.sock_send(&buf[sent_total..], vm) { Ok(result) => { let sent: usize = result @@ -1443,9 +1473,17 @@ pub(super) fn ssl_do_handshake( } } - // If we exit the loop without completing handshake, return error - // Check rustls state to provide better error message + // If we exit the loop without completing handshake, return appropriate error if conn.is_handshaking() { + // For non-blocking sockets, return WantRead/WantWrite to signal caller + // should retry when socket is ready. This matches OpenSSL behavior. + if conn.wants_write() { + return Err(SslError::WantWrite); + } + if conn.wants_read() { + return Err(SslError::WantRead); + } + // Neither wants_read nor wants_write - this is a real error Err(SslError::Syscall(format!( "SSL handshake failed: incomplete after {iteration_count} iterations", ))) @@ -1581,6 +1619,14 @@ pub(super) fn ssl_read( if let Some(t) = timeout && t.is_zero() { + // Non-blocking socket: check if peer has closed before returning WantRead + // If close_notify was received, we should return ZeroReturn (EOF), not WantRead + // This is critical for asyncore-based applications that rely on recv() returning + // 0 or raising SSL_ERROR_ZERO_RETURN to detect connection close. + let io_state = conn.process_new_packets().map_err(SslError::from_rustls)?; + if io_state.peer_has_closed() { + return Err(SslError::ZeroReturn); + } // Non-blocking socket: return immediately return Err(SslError::WantRead); } @@ -1605,7 +1651,13 @@ pub(super) fn ssl_read( .unwrap_or(0); if bytes_read == 0 { - // No more data available - connection might be closed + // No more data available - check if this is clean shutdown or unexpected EOF + // If close_notify was already received, return ZeroReturn (clean closure) + // Otherwise, return Eof (unexpected EOF) + let io_state = conn.process_new_packets().map_err(SslError::from_rustls)?; + if io_state.peer_has_closed() { + return Err(SslError::ZeroReturn); + } return Err(SslError::Eof); } @@ -1648,6 +1700,138 @@ pub(super) fn ssl_read( } } +/// Equivalent to OpenSSL's SSL_write() +/// +/// Writes application data to TLS connection. +/// Automatically handles TLS record I/O as needed. +/// +/// = SSL_write_ex() +pub(super) fn ssl_write( + conn: &mut TlsConnection, + data: &[u8], + socket: &PySSLSocket, + vm: &VirtualMachine, +) -> SslResult { + if data.is_empty() { + return Ok(0); + } + + let is_bio = socket.is_bio_mode(); + + // Get socket timeout and calculate deadline (= _PyDeadline_Init) + let deadline = if !is_bio { + match socket.get_socket_timeout(vm).map_err(SslError::Py)? { + Some(timeout) if !timeout.is_zero() => Some(std::time::Instant::now() + timeout), + _ => None, + } + } else { + None + }; + + // Flush any pending TLS output before writing new data + if !is_bio { + socket + .flush_pending_tls_output(vm, deadline) + .map_err(SslError::Py)?; + } + + // Check if we already have data buffered from a previous retry + // (prevents duplicate writes when retrying after WantWrite/WantRead) + let already_buffered = *socket.write_buffered_len.lock(); + + // Only write plaintext if not already buffered + if already_buffered == 0 { + // Write plaintext to rustls (= SSL_write_ex internal buffer write) + { + let mut writer = conn.writer(); + use std::io::Write; + writer + .write_all(data) + .map_err(|e| SslError::Syscall(format!("Write failed: {e}")))?; + } + // Mark data as buffered + *socket.write_buffered_len.lock() = data.len(); + } else if already_buffered != data.len() { + // Caller is retrying with different data - this is a protocol error + // Clear the buffer state and return an SSL error (bad write retry) + *socket.write_buffered_len.lock() = 0; + return Err(SslError::Ssl("bad write retry".to_string())); + } + // else: already_buffered == data.len(), this is a valid retry + + // Loop to send TLS records, handling WANT_READ/WANT_WRITE + // Matches CPython's do-while loop on SSL_ERROR_WANT_READ/WANT_WRITE + loop { + // Check deadline + if let Some(dl) = deadline + && std::time::Instant::now() >= dl + { + return Err(SslError::Timeout( + "The write operation timed out".to_string(), + )); + } + + // Check if rustls has TLS data to send + if !conn.wants_write() { + // All TLS data sent successfully + break; + } + + // Get TLS records from rustls + let tls_data = ssl_write_tls_records(conn)?; + if tls_data.is_empty() { + break; + } + + // Send TLS data to socket + match send_all_bytes(socket, tls_data, vm, deadline) { + Ok(()) => { + // Successfully sent, continue loop to check for more data + } + Err(SslError::WantWrite) => { + // Non-blocking socket would block - return WANT_WRITE + // Keep write_buffered_len set so we don't re-buffer on retry + return Err(SslError::WantWrite); + } + Err(SslError::WantRead) => { + // Need to read before write can complete (e.g., renegotiation) + // This matches CPython's handling of SSL_ERROR_WANT_READ in write + if is_bio { + // Keep write_buffered_len set so we don't re-buffer on retry + return Err(SslError::WantRead); + } + // For socket mode, try to read TLS data + let recv_result = socket.sock_recv(4096, vm).map_err(SslError::Py)?; + ssl_read_tls_records(conn, recv_result, false, vm)?; + conn.process_new_packets().map_err(SslError::from_rustls)?; + // Continue loop + } + Err(e @ SslError::Timeout(_)) => { + // Preserve buffered state so retry doesn't duplicate data + // (send_all_bytes saved unsent TLS bytes to pending_tls_output) + return Err(e); + } + Err(e) => { + // Clear buffer state on error + *socket.write_buffered_len.lock() = 0; + return Err(e); + } + } + } + + // Final flush to ensure all data is sent + if !is_bio { + socket + .flush_pending_tls_output(vm, deadline) + .map_err(SslError::Py)?; + } + + // Write completed successfully - clear buffer state + *socket.write_buffered_len.lock() = 0; + + Ok(data.len()) +} + // Helper functions (private-ish, used by public SSL functions) /// Write TLS records from rustls to socket @@ -1684,26 +1868,24 @@ fn ssl_read_tls_records( // 1. Clean shutdown: received TLS close_notify → return ZeroReturn (0 bytes) // 2. Unexpected EOF: no close_notify → return Eof (SSLEOFError) // - // SSL_ERROR_ZERO_RETURN vs SSL_ERROR_SYSCALL(errno=0) logic + // SSL_ERROR_ZERO_RETURN vs SSL_ERROR_EOF logic // CPython checks SSL_get_shutdown() & SSL_RECEIVED_SHUTDOWN // // Process any buffered TLS records (may contain close_notify) - let _ = conn.process_new_packets(); - - // IMPORTANT: CPython's default behavior (suppress_ragged_eofs=True) - // treats empty recv() as clean shutdown, returning 0 bytes instead of raising SSLEOFError. - // - // This is necessary for HTTP/1.0 servers that: - // 1. Send response without Content-Length header - // 2. Signal end-of-response by closing connection (TCP FIN) - // 3. Don't send TLS close_notify before TCP close - // - // While this could theoretically allow truncation attacks, - // it's the standard behavior for compatibility with real-world servers. - // Python only raises SSLEOFError when suppress_ragged_eofs=False is explicitly set. - // - // TODO: Implement suppress_ragged_eofs parameter if needed for strict security mode. - return Err(SslError::ZeroReturn); + match conn.process_new_packets() { + Ok(io_state) => { + if io_state.peer_has_closed() { + // Received close_notify - normal SSL closure (SSL_ERROR_ZERO_RETURN) + return Err(SslError::ZeroReturn); + } else { + // No close_notify - ragged EOF (SSL_ERROR_EOF → SSLEOFError) + // CPython raises SSLEOFError here, which SSLSocket.read() handles + // based on suppress_ragged_eofs setting + return Err(SslError::Eof); + } + } + Err(e) => return Err(SslError::from_rustls(e)), + } } } @@ -1816,6 +1998,9 @@ fn ssl_ensure_data_available( let data = match socket.sock_recv(2048, vm) { Ok(data) => data, Err(e) => { + if is_blocking_io_error(&e, vm) { + return Err(SslError::WantRead); + } // Before returning socket error, check if rustls already has a queued TLS alert // This mirrors CPython/OpenSSL behavior: SSL errors take precedence over socket errors // On Windows, TCP RST may arrive before we read the alert, but rustls may have diff --git a/crates/stdlib/src/ssl/error.rs b/crates/stdlib/src/ssl/error.rs index 6219eff41b5..cbc59e0e8f6 100644 --- a/crates/stdlib/src/ssl/error.rs +++ b/crates/stdlib/src/ssl/error.rs @@ -132,4 +132,15 @@ pub(crate) mod ssl_error { "TLS/SSL connection has been closed (EOF)", ) } + + pub fn create_ssl_syscall_error( + vm: &VirtualMachine, + msg: impl Into, + ) -> PyRef { + vm.new_os_subtype_error( + PySSLSyscallError::class(&vm.ctx).to_owned(), + Some(SSL_ERROR_SYSCALL), + msg.into(), + ) + } } diff --git a/crates/vm/src/stdlib/thread.rs b/crates/vm/src/stdlib/thread.rs index 9f0c0535d71..f7e47b15deb 100644 --- a/crates/vm/src/stdlib/thread.rs +++ b/crates/vm/src/stdlib/thread.rs @@ -421,14 +421,14 @@ pub(crate) mod _thread { vm.new_thread() .make_spawn_func(move |vm| run_thread(func, args, vm)), ) - .map(|handle| { - vm.state.thread_count.fetch_add(1); - thread_to_id(&handle) - }) + .map(|handle| thread_to_id(&handle)) .map_err(|err| vm.new_runtime_error(format!("can't start new thread: {err}"))) } fn run_thread(func: ArgCallable, args: FuncArgs, vm: &VirtualMachine) { + // Increment thread count when thread actually starts executing + vm.state.thread_count.fetch_add(1); + match func.invoke(args, vm) { Ok(_obj) => {} Err(e) if e.fast_isinstance(vm.ctx.exceptions.system_exit) => {} @@ -1168,13 +1168,6 @@ pub(crate) mod _thread { // Mark as done inner_for_cleanup.lock().state = ThreadHandleState::Done; - // Signal waiting threads that this thread is done - { - let (lock, cvar) = &*done_event_for_cleanup; - *lock.lock() = true; - cvar.notify_all(); - } - // Handle sentinels for lock in SENTINELS.take() { if lock.mu.is_locked() { @@ -1189,8 +1182,19 @@ pub(crate) mod _thread { crate::vm::thread::cleanup_current_thread_frames(vm); vm_state.thread_count.fetch_sub(1); + + // Signal waiting threads that this thread is done + // This must be LAST to ensure all cleanup is complete before join() returns + { + let (lock, cvar) = &*done_event_for_cleanup; + *lock.lock() = true; + cvar.notify_all(); + } } + // Increment thread count when thread actually starts executing + vm_state.thread_count.fetch_add(1); + // Run the function match func.invoke((), vm) { Ok(_) => {} @@ -1206,8 +1210,6 @@ pub(crate) mod _thread { })) .map_err(|err| vm.new_runtime_error(format!("can't start new thread: {err}")))?; - vm.state.thread_count.fetch_add(1); - // Store the join handle handle.inner.lock().join_handle = Some(join_handle);