Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 56 additions & 8 deletions rustls-test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -969,19 +969,67 @@ pub fn server_name(name: &'static str) -> ServerName<'static> {
name.try_into().unwrap()
}

pub struct FailsReads {
errkind: io::ErrorKind,
/// An object that impls `io::Read` and `io::Write` for testing.
///
/// The `reads` and `writes` fields set the behaviour of these trait
/// implementations. They return the `WouldBlock` error if not otherwise
/// configured -- `TestNonBlockIo::default()` does this permanently.
///
/// This object panics on drop if the configured expected reads/writes
/// didn't take place.
#[derive(Debug, Default)]
pub struct TestNonBlockIo {
/// Each `write()` call is satisfied by inspecting this field.
///
/// If it is empty, `WouldBlock` is returned. Otherwise the write is
/// satisfied by popping a value and returning it (reduced by the size
/// of the write buffer, if needed).
pub writes: Vec<usize>,

/// Each `read()` call is satisfied by inspecting this field.
///
/// If it is empty, `WouldBlock` is returned. Otherwise the read is
/// satisfied by popping a value and copying it into the output
/// buffer. Each value must be no longer than the buffer for that
/// call.
pub reads: Vec<Vec<u8>>,
}

impl FailsReads {
pub fn new(errkind: io::ErrorKind) -> Self {
Self { errkind }
impl io::Read for TestNonBlockIo {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
println!("read {:?}", buf.len());
match self.reads.pop() {
None => Err(io::ErrorKind::WouldBlock.into()),
Some(data) => {
assert!(data.len() <= buf.len());
let take = core::cmp::min(data.len(), buf.len());
buf[..take].clone_from_slice(&data[..take]);
Ok(take)
}
}
}
}

impl io::Write for TestNonBlockIo {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
println!("write {:?}", buf.len());
match self.writes.pop() {
None => Err(io::ErrorKind::WouldBlock.into()),
Some(n) => Ok(core::cmp::min(n, buf.len())),
}
}

fn flush(&mut self) -> io::Result<()> {
println!("flush");
Ok(())
}
}

impl io::Read for FailsReads {
fn read(&mut self, _b: &mut [u8]) -> io::Result<usize> {
Err(io::Error::from(self.errkind))
impl Drop for TestNonBlockIo {
fn drop(&mut self) {
// ensure the object was exhausted as expected
assert!(self.reads.is_empty());
assert!(self.writes.is_empty());
}
}

Expand Down
6 changes: 2 additions & 4 deletions rustls/benches/benchmarks.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
#![cfg(feature = "ring")]
#![allow(clippy::disallowed_types)]

use std::io;
use std::sync::Arc;

use bencher::{Bencher, benchmark_group, benchmark_main};
use rustls::ServerConnection;
use rustls::crypto::ring as provider;
use rustls_test::{FailsReads, KeyType, make_server_config};
use rustls_test::{KeyType, TestNonBlockIo, make_server_config};

fn bench_ewouldblock(c: &mut Bencher) {
let server_config = make_server_config(KeyType::Rsa2048, &provider::default_provider());
let mut server = ServerConnection::new(Arc::new(server_config)).unwrap();
let mut read_ewouldblock = FailsReads::new(io::ErrorKind::WouldBlock);
c.iter(|| server.read_tls(&mut read_ewouldblock));
c.iter(|| server.read_tls(&mut TestNonBlockIo::default()));
}

benchmark_group!(benches, bench_ewouldblock);
Expand Down
20 changes: 19 additions & 1 deletion rustls/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,15 @@ impl<Data> ConnectionCommon<Data> {
return Ok((rdlen, wrlen));
}

// If we want to write, but are WouldBlocked by the underlying IO, *and*
// have no desire to read; that is everything.
if let (Some(_), false) = (&blocked_write, self.wants_read()) {
return match wrlen {
0 => Err(blocked_write.unwrap()),
_ => Ok((rdlen, wrlen)),
};
}

while !eof && self.wants_read() {
let read_size = match self.read_tls(io) {
Ok(0) => {
Expand Down Expand Up @@ -633,6 +642,15 @@ impl<Data> ConnectionCommon<Data> {
return Err(io::Error::new(io::ErrorKind::InvalidData, e));
};

// If we want to read, but are WouldBlocked by the underlying IO, *and*
// have no desire to write; that is everything.
if let (Some(_), false) = (&blocked_read, self.wants_write()) {
return match rdlen {
0 => Err(blocked_read.unwrap()),
_ => Ok((rdlen, wrlen)),
};
}

// if we're doing IO until handshaked, and we believe we've finished handshaking,
// but process_new_packets() has queued TLS data to send, loop around again to write
// the queued messages.
Expand All @@ -643,9 +661,9 @@ impl<Data> ConnectionCommon<Data> {
let blocked = blocked_write.zip(blocked_read);
match (eof, until_handshaked, self.is_handshaking(), blocked) {
(_, true, false, _) => return Ok((rdlen, wrlen)),
(_, _, _, Some((e, _))) if rdlen == 0 && wrlen == 0 => return Err(e),
(_, false, _, _) => return Ok((rdlen, wrlen)),
(true, true, true, _) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)),
(_, _, _, Some((e, _))) => return Err(e),
_ => {}
}
}
Expand Down
93 changes: 93 additions & 0 deletions rustls/tests/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2621,6 +2621,99 @@ fn client_complete_io_for_write() {
}
}

#[test]
fn client_complete_io_with_nonblocking_io() {
let (mut client, _) = make_pair(KeyType::Rsa2048, &provider::default_provider());

// absolutely no progress writing ClientHello
assert_eq!(
client
.complete_io(&mut TestNonBlockIo::default())
.unwrap_err()
.kind(),
io::ErrorKind::WouldBlock
);

// a little progress writing ClientHello
let (mut client, _) = make_pair(KeyType::Rsa2048, &provider::default_provider());
assert_eq!(
client
.complete_io(&mut TestNonBlockIo {
writes: vec![1],
reads: vec![],
})
.unwrap(),
(0, 1)
);

// complete writing ClientHello
let (mut client, _) = make_pair(KeyType::Rsa2048, &provider::default_provider());
assert_eq!(
client
.complete_io(&mut TestNonBlockIo {
writes: vec![4096],
reads: vec![],
})
.unwrap_err()
.kind(),
io::ErrorKind::WouldBlock
);

// complete writing ClientHello, partial read of ServerHello
let (mut client, _) = make_pair(KeyType::Rsa2048, &provider::default_provider());
let (rd, wr) = dbg!(client.complete_io(&mut TestNonBlockIo {
writes: vec![4096],
reads: vec![vec![ContentType::Handshake.into()]],
}))
.unwrap();
assert_eq!(rd, 1);
assert!(wr > 1);

// data phase:
let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider());
do_handshake(&mut client, &mut server);

// read
assert_eq!(
client
.complete_io(&mut TestNonBlockIo {
reads: vec![vec![ContentType::ApplicationData.into()]],
writes: vec![],
})
.unwrap(),
(1, 0)
);

// write
client
.writer()
.write_all(b"hello")
.unwrap();

// no progress
assert_eq!(
client
.complete_io(&mut TestNonBlockIo {
reads: vec![],
writes: vec![],
})
.unwrap_err()
.kind(),
io::ErrorKind::WouldBlock
);

// some write progress
assert_eq!(
client
.complete_io(&mut TestNonBlockIo {
reads: vec![],
writes: vec![1],
})
.unwrap(),
(0, 1)
);
}

#[test]
fn buffered_client_complete_io_for_write() {
let provider = provider::default_provider();
Expand Down
Loading