diff --git a/rustls-test/src/lib.rs b/rustls-test/src/lib.rs index 51819a4df58..9ea2e80c1e3 100644 --- a/rustls-test/src/lib.rs +++ b/rustls-test/src/lib.rs @@ -969,19 +969,67 @@ pub fn server_name(name: &'static str) -> ServerName<'static> { name.try_into().unwrap() } -pub struct FailsReads { - errkind: io::ErrorKind, +/// An object that impls `io::Read` and `io::Write` for testing. +/// +/// The `reads` and `writes` fields set the behaviour of these trait +/// implementations. They return the `WouldBlock` error if not otherwise +/// configured -- `TestNonBlockIo::default()` does this permanently. +/// +/// This object panics on drop if the configured expected reads/writes +/// didn't take place. +#[derive(Debug, Default)] +pub struct TestNonBlockIo { + /// Each `write()` call is satisfied by inspecting this field. + /// + /// If it is empty, `WouldBlock` is returned. Otherwise the write is + /// satisfied by popping a value and returning it (reduced by the size + /// of the write buffer, if needed). + pub writes: Vec, + + /// Each `read()` call is satisfied by inspecting this field. + /// + /// If it is empty, `WouldBlock` is returned. Otherwise the read is + /// satisfied by popping a value and copying it into the output + /// buffer. Each value must be no longer than the buffer for that + /// call. + pub reads: Vec>, } -impl FailsReads { - pub fn new(errkind: io::ErrorKind) -> Self { - Self { errkind } +impl io::Read for TestNonBlockIo { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + println!("read {:?}", buf.len()); + match self.reads.pop() { + None => Err(io::ErrorKind::WouldBlock.into()), + Some(data) => { + assert!(data.len() <= buf.len()); + let take = core::cmp::min(data.len(), buf.len()); + buf[..take].clone_from_slice(&data[..take]); + Ok(take) + } + } + } +} + +impl io::Write for TestNonBlockIo { + fn write(&mut self, buf: &[u8]) -> io::Result { + println!("write {:?}", buf.len()); + match self.writes.pop() { + None => Err(io::ErrorKind::WouldBlock.into()), + Some(n) => Ok(core::cmp::min(n, buf.len())), + } + } + + fn flush(&mut self) -> io::Result<()> { + println!("flush"); + Ok(()) } } -impl io::Read for FailsReads { - fn read(&mut self, _b: &mut [u8]) -> io::Result { - Err(io::Error::from(self.errkind)) +impl Drop for TestNonBlockIo { + fn drop(&mut self) { + // ensure the object was exhausted as expected + assert!(self.reads.is_empty()); + assert!(self.writes.is_empty()); } } diff --git a/rustls/benches/benchmarks.rs b/rustls/benches/benchmarks.rs index 5aebc995f0c..5fc473de872 100644 --- a/rustls/benches/benchmarks.rs +++ b/rustls/benches/benchmarks.rs @@ -1,19 +1,17 @@ #![cfg(feature = "ring")] #![allow(clippy::disallowed_types)] -use std::io; use std::sync::Arc; use bencher::{Bencher, benchmark_group, benchmark_main}; use rustls::ServerConnection; use rustls::crypto::ring as provider; -use rustls_test::{FailsReads, KeyType, make_server_config}; +use rustls_test::{KeyType, TestNonBlockIo, make_server_config}; fn bench_ewouldblock(c: &mut Bencher) { let server_config = make_server_config(KeyType::Rsa2048, &provider::default_provider()); let mut server = ServerConnection::new(Arc::new(server_config)).unwrap(); - let mut read_ewouldblock = FailsReads::new(io::ErrorKind::WouldBlock); - c.iter(|| server.read_tls(&mut read_ewouldblock)); + c.iter(|| server.read_tls(&mut TestNonBlockIo::default())); } benchmark_group!(benches, bench_ewouldblock); diff --git a/rustls/src/conn.rs b/rustls/src/conn.rs index 7bacd3f2472..25d573a8850 100644 --- a/rustls/src/conn.rs +++ b/rustls/src/conn.rs @@ -603,6 +603,15 @@ impl ConnectionCommon { return Ok((rdlen, wrlen)); } + // If we want to write, but are WouldBlocked by the underlying IO, *and* + // have no desire to read; that is everything. + if let (Some(_), false) = (&blocked_write, self.wants_read()) { + return match wrlen { + 0 => Err(blocked_write.unwrap()), + _ => Ok((rdlen, wrlen)), + }; + } + while !eof && self.wants_read() { let read_size = match self.read_tls(io) { Ok(0) => { @@ -633,6 +642,15 @@ impl ConnectionCommon { return Err(io::Error::new(io::ErrorKind::InvalidData, e)); }; + // If we want to read, but are WouldBlocked by the underlying IO, *and* + // have no desire to write; that is everything. + if let (Some(_), false) = (&blocked_read, self.wants_write()) { + return match rdlen { + 0 => Err(blocked_read.unwrap()), + _ => Ok((rdlen, wrlen)), + }; + } + // if we're doing IO until handshaked, and we believe we've finished handshaking, // but process_new_packets() has queued TLS data to send, loop around again to write // the queued messages. @@ -643,9 +661,9 @@ impl ConnectionCommon { let blocked = blocked_write.zip(blocked_read); match (eof, until_handshaked, self.is_handshaking(), blocked) { (_, true, false, _) => return Ok((rdlen, wrlen)), + (_, _, _, Some((e, _))) if rdlen == 0 && wrlen == 0 => return Err(e), (_, false, _, _) => return Ok((rdlen, wrlen)), (true, true, true, _) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)), - (_, _, _, Some((e, _))) => return Err(e), _ => {} } } diff --git a/rustls/tests/api.rs b/rustls/tests/api.rs index e87fd8cc812..006c8187af7 100644 --- a/rustls/tests/api.rs +++ b/rustls/tests/api.rs @@ -2621,6 +2621,99 @@ fn client_complete_io_for_write() { } } +#[test] +fn client_complete_io_with_nonblocking_io() { + let (mut client, _) = make_pair(KeyType::Rsa2048, &provider::default_provider()); + + // absolutely no progress writing ClientHello + assert_eq!( + client + .complete_io(&mut TestNonBlockIo::default()) + .unwrap_err() + .kind(), + io::ErrorKind::WouldBlock + ); + + // a little progress writing ClientHello + let (mut client, _) = make_pair(KeyType::Rsa2048, &provider::default_provider()); + assert_eq!( + client + .complete_io(&mut TestNonBlockIo { + writes: vec![1], + reads: vec![], + }) + .unwrap(), + (0, 1) + ); + + // complete writing ClientHello + let (mut client, _) = make_pair(KeyType::Rsa2048, &provider::default_provider()); + assert_eq!( + client + .complete_io(&mut TestNonBlockIo { + writes: vec![4096], + reads: vec![], + }) + .unwrap_err() + .kind(), + io::ErrorKind::WouldBlock + ); + + // complete writing ClientHello, partial read of ServerHello + let (mut client, _) = make_pair(KeyType::Rsa2048, &provider::default_provider()); + let (rd, wr) = dbg!(client.complete_io(&mut TestNonBlockIo { + writes: vec![4096], + reads: vec![vec![ContentType::Handshake.into()]], + })) + .unwrap(); + assert_eq!(rd, 1); + assert!(wr > 1); + + // data phase: + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); + do_handshake(&mut client, &mut server); + + // read + assert_eq!( + client + .complete_io(&mut TestNonBlockIo { + reads: vec![vec![ContentType::ApplicationData.into()]], + writes: vec![], + }) + .unwrap(), + (1, 0) + ); + + // write + client + .writer() + .write_all(b"hello") + .unwrap(); + + // no progress + assert_eq!( + client + .complete_io(&mut TestNonBlockIo { + reads: vec![], + writes: vec![], + }) + .unwrap_err() + .kind(), + io::ErrorKind::WouldBlock + ); + + // some write progress + assert_eq!( + client + .complete_io(&mut TestNonBlockIo { + reads: vec![], + writes: vec![1], + }) + .unwrap(), + (0, 1) + ); +} + #[test] fn buffered_client_complete_io_for_write() { let provider = provider::default_provider();