diff --git a/src/backup/data_blob.rs b/src/backup/data_blob.rs index 144dc808..fc8c9bc2 100644 --- a/src/backup/data_blob.rs +++ b/src/backup/data_blob.rs @@ -307,7 +307,7 @@ impl DataBlob { } -use std::io::{Read, BufRead, Write, Seek, SeekFrom}; +use std::io::{Read, BufRead, BufReader, Write, Seek, SeekFrom}; struct CryptWriter { writer: W, @@ -647,41 +647,114 @@ impl <'a, W: Write + Seek> Write for DataBlobWriter<'a, W> { } } -/// Read compressed data blobs -pub struct CompressedDataBlobReader { - decompr: zstd::stream::read::Decoder, - hasher: Option, - expected_crc: u32, +struct ChecksumReader<'a, R> { + reader: R, + hasher: crc32fast::Hasher, + signer: Option>, } -impl CompressedDataBlobReader { +impl <'a, R: Read> ChecksumReader<'a, R> { - pub fn new(mut reader: R) -> Result { + fn new(reader: R, signer: Option>) -> Self { + let hasher = crc32fast::Hasher::new(); + Self { reader, hasher, signer } + } - let head: DataBlobHeader = unsafe { reader.read_le_value()? }; - if head.magic != COMPRESSED_BLOB_MAGIC_1_0 { - bail!("got wrong magic number"); + pub fn finish(mut self) -> Result<(R, u32, Option<[u8; 32]>), Error> { + let crc = self.hasher.finalize(); + + if let Some(ref mut signer) = self.signer { + let mut tag = [0u8; 32]; + signer.sign(&mut tag)?; + Ok((self.reader, crc, Some(tag))) + } else { + Ok((self.reader, crc, None)) } - let expected_crc = u32::from_le_bytes(head.crc); - let decompr = zstd::stream::read::Decoder::with_buffer(reader)?; - Ok(Self { decompr: decompr, hasher: Some(crc32fast::Hasher::new()), expected_crc }) } } -impl Read for CompressedDataBlobReader { +impl <'a, R: Read> Read for ChecksumReader<'a, R> { fn read(&mut self, buf: &mut [u8]) -> Result { - let count = self.decompr.read(buf)?; - if count == 0 { // EOF, verify crc - let hasher = self.hasher.take().expect("blob reader already finished"); - let crc = hasher.finalize(); - if crc != self.expected_crc { - return Err(std::io::Error::new(std::io::ErrorKind::Other, "blob reader crc error")); + let count = self.reader.read(buf)?; + if count > 0 { + self.hasher.update(&buf[..count]); + if let Some(ref mut signer) = self.signer { + signer.update(&buf[..count]) + .map_err(|err| { + std::io::Error::new( + std::io::ErrorKind::Other, + format!("hmac update failed - {}", err)) + })?; } - } else { - let hasher = self.hasher.as_mut().expect("blob reader already finished"); - hasher.update(buf); } Ok(count) } } + +enum BlobReaderState<'a, R: Read> { + Uncompressed { expected_crc: u32, csum_reader: ChecksumReader<'a, R> }, + Compressed { expected_crc: u32, decompr: zstd::stream::read::Decoder>> }, +} + +/// Read data blobs +pub struct DataBlobReader<'a, R: Read> { + state: BlobReaderState<'a, R>, +} + +impl <'a, R: Read> DataBlobReader<'a, R> { + + pub fn new(mut reader: R) -> Result { + + let head: DataBlobHeader = unsafe { reader.read_le_value()? }; + match head.magic { + UNCOMPRESSED_BLOB_MAGIC_1_0 => { + let expected_crc = u32::from_le_bytes(head.crc); + let csum_reader = ChecksumReader::new(reader, None); + Ok(Self { state: BlobReaderState::Uncompressed { expected_crc, csum_reader }}) + } + COMPRESSED_BLOB_MAGIC_1_0 => { + let expected_crc = u32::from_le_bytes(head.crc); + let csum_reader = ChecksumReader::new(reader, None); + + let decompr = zstd::stream::read::Decoder::new(csum_reader)?; + Ok(Self { state: BlobReaderState::Compressed { expected_crc, decompr }}) + } + _ => bail!("got wrong magic number {:?}", head.magic) + } + } + + pub fn finish(self) -> Result { + match self.state { + BlobReaderState::Uncompressed { csum_reader, expected_crc } => { + let (reader, crc, _) = csum_reader.finish()?; + if crc != expected_crc { + bail!("blob crc check failed"); + } + Ok(reader) + } + BlobReaderState::Compressed { expected_crc, decompr } => { + let csum_reader = decompr.finish().into_inner(); + let (reader, crc, _) = csum_reader.finish()?; + if crc != expected_crc { + bail!("blob crc check failed"); + } + Ok(reader) + } + } + } +} + +impl <'a, R: BufRead> Read for DataBlobReader<'a, R> { + + fn read(&mut self, buf: &mut [u8]) -> Result { + match &mut self.state { + BlobReaderState::Uncompressed { csum_reader, .. } => { + csum_reader.read(buf) + } + BlobReaderState::Compressed { decompr, .. } => { + decompr.read(buf) + } + } + } +}