Convert some panics that happen on invalid parquet files to error results #6738

Open: wants to merge 4 commits into master
7 changes: 7 additions & 0 deletions parquet/src/errors.rs
@@ -17,6 +17,7 @@

//! Common Parquet errors and macros.

+use core::num::TryFromIntError;
use std::error::Error;
use std::{cell, io, result, str};

@@ -76,6 +77,12 @@ impl Error for ParquetError {
}
}

+impl From<TryFromIntError> for ParquetError {
+    fn from(e: TryFromIntError) -> ParquetError {
+        ParquetError::General(format!("Integer overflow: {e}"))
+    }
+}
+
impl From<io::Error> for ParquetError {
fn from(e: io::Error) -> ParquetError {
ParquetError::External(Box::new(e))
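The `From<TryFromIntError>` impl above is what lets the `try_into()?` conversions later in this PR surface integer overflow as a `ParquetError`. A minimal sketch of the effect (the wrapper function here is hypothetical, not part of the diff):

```rust
use parquet::errors::Result;

// A negative or oversized count read from an untrusted header now surfaces
// as Err(ParquetError::General(..)) instead of wrapping around via `as u32`.
fn num_values_from_header(raw: i64) -> Result<u32> {
    // `?` works because TryFromIntError converts into ParquetError.
    Ok(raw.try_into()?)
}

fn main() {
    assert!(num_values_from_header(42).is_ok());
    assert!(num_values_from_header(-1).is_err());
    assert!(num_values_from_header(i64::MAX).is_err());
}
```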
26 changes: 13 additions & 13 deletions parquet/src/file/metadata/reader.rs
@@ -617,7 +617,8 @@ impl ParquetMetaDataReader {
for rg in t_file_metadata.row_groups {
row_groups.push(RowGroupMetaData::from_thrift(schema_descr.clone(), rg)?);
}
-        let column_orders = Self::parse_column_orders(t_file_metadata.column_orders, &schema_descr);
+        let column_orders =
+            Self::parse_column_orders(t_file_metadata.column_orders, &schema_descr)?;

let file_metadata = FileMetaData::new(
t_file_metadata.version,
@@ -635,15 +636,13 @@
fn parse_column_orders(
t_column_orders: Option<Vec<TColumnOrder>>,
schema_descr: &SchemaDescriptor,
-    ) -> Option<Vec<ColumnOrder>> {
+    ) -> Result<Option<Vec<ColumnOrder>>> {
match t_column_orders {
Some(orders) => {
// Should always be the case
-                assert_eq!(
-                    orders.len(),
-                    schema_descr.num_columns(),
-                    "Column order length mismatch"
-                );
+                if orders.len() != schema_descr.num_columns() {
+                    return Err(general_err!("Column order length mismatch"));
+                };
let mut res = Vec::new();
for (i, column) in schema_descr.columns().iter().enumerate() {
match orders[i] {
@@ -657,9 +656,9 @@
}
}
}
-                Some(res)
+                Ok(Some(res))
}
-            None => None,
+            None => Ok(None),
}
}
}
@@ -731,7 +730,7 @@ mod tests {
]);

assert_eq!(
-            ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr),
+            ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr).unwrap(),
Some(vec![
ColumnOrder::TYPE_DEFINED_ORDER(SortOrder::SIGNED),
ColumnOrder::TYPE_DEFINED_ORDER(SortOrder::SIGNED)
@@ -740,20 +739,21 @@

// Test when no column orders are defined.
assert_eq!(
-            ParquetMetaDataReader::parse_column_orders(None, &schema_descr),
+            ParquetMetaDataReader::parse_column_orders(None, &schema_descr).unwrap(),
None
);
}

#[test]
-    #[should_panic(expected = "Column order length mismatch")]
fn test_metadata_column_orders_len_mismatch() {
let schema = SchemaType::group_type_builder("schema").build().unwrap();
let schema_descr = SchemaDescriptor::new(Arc::new(schema));

let t_column_orders = Some(vec![TColumnOrder::TYPEORDER(TypeDefinedOrder::new())]);

-        ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr);
+        let res = ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr);
+        assert!(res.is_err());
+        assert!(format!("{:?}", res.unwrap_err()).contains("Column order length mismatch"));
}

#[test]
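From the caller's perspective, corrupt metadata is now reported as an error rather than aborting the process. A rough usage sketch (the byte contents are elided; `decode_metadata` is assumed here to be the public entry point that reaches `parse_column_orders`):

```rust
use parquet::file::metadata::ParquetMetaDataReader;

// Thrift-encoded footer metadata from an untrusted file (bytes elided).
// A column-order list whose length disagrees with the schema previously
// panicked with "Column order length mismatch"; it now returns Err.
fn load_untrusted_footer(footer: &[u8]) {
    match ParquetMetaDataReader::decode_metadata(footer) {
        Ok(metadata) => println!("row groups: {}", metadata.num_row_groups()),
        Err(e) => eprintln!("rejected invalid parquet file: {e}"),
    }
}
```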
53 changes: 46 additions & 7 deletions parquet/src/file/serialized_reader.rs
@@ -435,7 +435,7 @@ pub(crate) fn decode_page(
let is_sorted = dict_header.is_sorted.unwrap_or(false);
Page::DictionaryPage {
buf: buffer,
-                num_values: dict_header.num_values as u32,
+                num_values: dict_header.num_values.try_into()?,
encoding: Encoding::try_from(dict_header.encoding)?,
is_sorted,
}
@@ -446,7 +446,7 @@
.ok_or_else(|| ParquetError::General("Missing V1 data page header".to_string()))?;
Page::DataPage {
buf: buffer,
-                num_values: header.num_values as u32,
+                num_values: header.num_values.try_into()?,
encoding: Encoding::try_from(header.encoding)?,
def_level_encoding: Encoding::try_from(header.definition_level_encoding)?,
rep_level_encoding: Encoding::try_from(header.repetition_level_encoding)?,
@@ -460,12 +460,12 @@
let is_compressed = header.is_compressed.unwrap_or(true);
Page::DataPageV2 {
buf: buffer,
-                num_values: header.num_values as u32,
+                num_values: header.num_values.try_into()?,
encoding: Encoding::try_from(header.encoding)?,
-                num_nulls: header.num_nulls as u32,
-                num_rows: header.num_rows as u32,
-                def_levels_byte_len: header.definition_levels_byte_length as u32,
-                rep_levels_byte_len: header.repetition_levels_byte_length as u32,
+                num_nulls: header.num_nulls.try_into()?,
+                num_rows: header.num_rows.try_into()?,
+                def_levels_byte_len: header.definition_levels_byte_length.try_into()?,
+                rep_levels_byte_len: header.repetition_levels_byte_length.try_into()?,
is_compressed,
statistics: statistics::from_thrift(physical_type, header.statistics)?,
}
@@ -578,6 +578,27 @@ impl<R: ChunkReader> Iterator for SerializedPageReader<R> {
}
}

+fn verify_page_header_len(header_len: usize, remaining_bytes: usize) -> Result<()> {
+    if header_len > remaining_bytes {
+        return Err(eof_err!("Invalid page header"));
+    }
+    Ok(())
+}
+
+fn verify_page_size(
+    compressed_size: i32,
+    uncompressed_size: i32,
+    remaining_bytes: usize,
+) -> Result<()> {
+    // The page's compressed size should not exceed the remaining bytes that are
+    // available to read. The page's uncompressed size is the expected size
+    // after decompression, which can never be negative.
+    if compressed_size < 0 || compressed_size as usize > remaining_bytes || uncompressed_size < 0 {
+        return Err(eof_err!("Invalid page header"));
+    }
+    Ok(())
+}
+
impl<R: ChunkReader> PageReader for SerializedPageReader<R> {
fn get_next_page(&mut self) -> Result<Option<Page>> {
loop {
@@ -596,10 +617,16 @@
*header
} else {
let (header_len, header) = read_page_header_len(&mut read)?;
+                        verify_page_header_len(header_len, *remaining)?;
*offset += header_len;
*remaining -= header_len;
header
};
+                    verify_page_size(
+                        header.compressed_page_size,
+                        header.uncompressed_page_size,
+                        *remaining,
+                    )?;
let data_len = header.compressed_page_size as usize;
*offset += data_len;
*remaining -= data_len;
@@ -683,6 +710,7 @@ impl<R: ChunkReader> PageReader for SerializedPageReader<R> {
} else {
let mut read = self.reader.get_read(*offset as u64)?;
let (header_len, header) = read_page_header_len(&mut read)?;
+                        verify_page_header_len(header_len, *remaining_bytes)?;
*offset += header_len;
*remaining_bytes -= header_len;
let page_meta = if let Ok(page_meta) = (&header).try_into() {
@@ -733,12 +761,23 @@ impl<R: ChunkReader> PageReader for SerializedPageReader<R> {
next_page_header,
} => {
if let Some(buffered_header) = next_page_header.take() {
+                        verify_page_size(
+                            buffered_header.compressed_page_size,
+                            buffered_header.uncompressed_page_size,
+                            *remaining_bytes,
+                        )?;
// The next page header has already been peeked, so just advance the offset
*offset += buffered_header.compressed_page_size as usize;
*remaining_bytes -= buffered_header.compressed_page_size as usize;
} else {
let mut read = self.reader.get_read(*offset as u64)?;
let (header_len, header) = read_page_header_len(&mut read)?;
+                        verify_page_header_len(header_len, *remaining_bytes)?;
+                        verify_page_size(
+                            header.compressed_page_size,
+                            header.uncompressed_page_size,
+                            *remaining_bytes,
+                        )?;
let data_page_size = header.compressed_page_size as usize;
*offset += header_len + data_page_size;
*remaining_bytes -= header_len + data_page_size;
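The explicit sign checks in `verify_page_size` are worth spelling out: casting a negative `i32` to `usize` with `as` sign-extends and wraps to a huge value, and `uncompressed_size` is never compared against `remaining_bytes`, so it must be sign-checked on its own. A tiny illustrative snippet (not part of the diff):

```rust
fn main() {
    let compressed_size: i32 = -1;
    // The cast sign-extends, so -1 becomes usize::MAX. The size comparison
    // would still catch that for compressed_size, but uncompressed_size has
    // no such comparison and needs the explicit `< 0` check.
    assert_eq!(compressed_size as usize, usize::MAX);
}
```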
26 changes: 26 additions & 0 deletions parquet/src/file/statistics.rs
@@ -157,6 +157,32 @@ pub fn from_thrift(
stats.max_value
};

+    fn check_len(min: &Option<Vec<u8>>, max: &Option<Vec<u8>>, len: usize) -> Result<()> {
+        if let Some(min) = min {
+            if min.len() < len {
+                return Err(ParquetError::General(
+                    "Insufficient bytes to parse min statistic".to_string(),
+                ));
+            }
+        }
+        if let Some(max) = max {
+            if max.len() < len {
+                return Err(ParquetError::General(
+                    "Insufficient bytes to parse max statistic".to_string(),
+                ));
+            }
+        }
+        Ok(())
+    }
+
+    match physical_type {
+        Type::BOOLEAN => check_len(&min, &max, 1),
+        Type::INT32 | Type::FLOAT => check_len(&min, &max, 4),
+        Type::INT64 | Type::DOUBLE => check_len(&min, &max, 8),
+        Type::INT96 => check_len(&min, &max, 12),
+        _ => Ok(()),
+    }?;
+
// Values are encoded using PLAIN encoding definition, except that
// variable-length byte arrays do not include a length prefix.
//
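The new `check_len` guard enforces the minimum PLAIN-encoded width per physical type before any decoding happens. A sketch of what this rejects, assuming the thrift-generated `parquet::format::Statistics` derives `Default`:

```rust
use parquet::basic::Type;
use parquet::file::statistics::from_thrift;
use parquet::format::Statistics as TStatistics;

fn main() {
    // An INT32 statistic needs at least 4 bytes; a truncated 2-byte
    // min_value from a corrupted file now fails fast with an error
    // instead of panicking inside the PLAIN decoder.
    let stats = TStatistics {
        min_value: Some(vec![0x01, 0x02]),
        ..Default::default() // assumes Default is derived for this struct
    };
    assert!(from_thrift(Type::INT32, Some(stats)).is_err());
}
```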
22 changes: 21 additions & 1 deletion parquet/src/schema/types.rs
@@ -533,7 +533,8 @@ impl<'a> PrimitiveTypeBuilder<'a> {
}
}
PhysicalType::FIXED_LEN_BYTE_ARRAY => {
-                let max_precision = (2f64.powi(8 * self.length - 1) - 1f64).log10().floor() as i32;
+                let length = self.length.checked_mul(8).unwrap_or(i32::MAX);
+                let max_precision = (2f64.powi(length - 1) - 1f64).log10().floor() as i32;

if self.precision > max_precision {
return Err(general_err!(
@@ -1122,9 +1123,25 @@ pub fn from_thrift(elements: &[SchemaElement]) -> Result<TypePtr> {
));
}

+    if !schema_nodes[0].is_group() {
+        return Err(general_err!("Expected root node to be a group type"));
+    }
+
Ok(schema_nodes.remove(0))
}

+/// Checks if the logical type is valid.
+fn check_logical_type(logical_type: &Option<LogicalType>) -> Result<()> {
+    if let Some(LogicalType::Integer { bit_width, .. }) = *logical_type {
+        if bit_width != 8 && bit_width != 16 && bit_width != 32 && bit_width != 64 {
+            return Err(general_err!(
+                "Bit width must be 8, 16, 32, or 64 for Integer logical type"
+            ));
+        }
+    }
+    Ok(())
+}
+
/// Constructs a new Type from the `elements`, starting at index `index`.
/// The first result is the starting index for the next Type after this one. If it is
/// equal to `elements.len()`, then this Type is the last one.
@@ -1149,6 +1166,9 @@ fn from_thrift_helper(elements: &[SchemaElement], index: usize) -> Result<(usize
.logical_type
.as_ref()
.map(|value| LogicalType::from(value.clone()));

+    check_logical_type(&logical_type)?;

let field_id = elements[index].field_id;
match elements[index].num_children {
// From parquet-format:
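For context on the `checked_mul` change: the decimal precision bound for a FIXED_LEN_BYTE_ARRAY of `length` bytes is floor(log10(2^(8 * length - 1) - 1)), and `8 * self.length` could overflow `i32` for an absurd length in a corrupted schema. A worked example of the formula (illustrative only):

```rust
fn main() {
    // A 16-byte FIXED_LEN_BYTE_ARRAY holds a signed 128-bit integer, so the
    // largest representable decimal precision is floor(log10(2^127 - 1)) = 38.
    let length: i32 = 16;
    let bits = length.checked_mul(8).unwrap_or(i32::MAX); // saturate on overflow
    let max_precision = (2f64.powi(bits - 1) - 1f64).log10().floor() as i32;
    assert_eq!(max_precision, 38);
}
```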
35 changes: 29 additions & 6 deletions parquet/src/thrift.rs
@@ -67,7 +67,7 @@ impl<'a> TCompactSliceInputProtocol<'a> {
let mut shift = 0;
loop {
let byte = self.read_byte()?;
-            in_progress |= ((byte & 0x7F) as u64) << shift;
+            in_progress |= ((byte & 0x7F) as u64).wrapping_shl(shift);
shift += 7;
if byte & 0x80 == 0 {
return Ok(in_progress);
@@ -96,13 +96,22 @@
}
}

+macro_rules! thrift_unimplemented {
+    () => {
+        Err(thrift::Error::Protocol(thrift::ProtocolError {
+            kind: thrift::ProtocolErrorKind::NotImplemented,
+            message: "not implemented".to_string(),
+        }))
+    };
+}

impl TInputProtocol for TCompactSliceInputProtocol<'_> {
fn read_message_begin(&mut self) -> thrift::Result<TMessageIdentifier> {
unimplemented!()
}

fn read_message_end(&mut self) -> thrift::Result<()> {
-        unimplemented!()
+        thrift_unimplemented!()
}

fn read_struct_begin(&mut self) -> thrift::Result<Option<TStructIdentifier>> {
@@ -147,7 +156,21 @@ impl TInputProtocol for TCompactSliceInputProtocol<'_> {
),
_ => {
if field_delta != 0 {
-                    self.last_read_field_id += field_delta as i16;
+                    self.last_read_field_id = self
+                        .last_read_field_id
+                        .checked_add(field_delta as i16)
+                        .map_or_else(
+                            || {
+                                Err(thrift::Error::Protocol(thrift::ProtocolError {
+                                    kind: thrift::ProtocolErrorKind::InvalidData,
+                                    message: format!(
+                                        "cannot add {} to {}",
+                                        field_delta, self.last_read_field_id
+                                    ),
+                                }))
+                            },
+                            Ok,
+                        )?;
} else {
self.last_read_field_id = self.read_i16()?;
};
@@ -226,15 +249,15 @@
}

fn read_set_begin(&mut self) -> thrift::Result<TSetIdentifier> {
-        unimplemented!()
+        thrift_unimplemented!()
}

fn read_set_end(&mut self) -> thrift::Result<()> {
-        unimplemented!()
+        thrift_unimplemented!()
}

fn read_map_begin(&mut self) -> thrift::Result<TMapIdentifier> {
-        unimplemented!()
+        thrift_unimplemented!()
}

fn read_map_end(&mut self) -> thrift::Result<()> {
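On the `wrapping_shl` change: a malformed varint can carry more than ten continuation bytes, pushing `shift` past 63, and a plain `<<` panics in a debug build when the shift amount reaches the integer's bit width. `wrapping_shl` masks the shift amount modulo the bit width on all profiles, so decoding continues to the terminating byte and yields a garbage value that later validation rejects. Illustrative snippet (not part of the diff):

```rust
fn main() {
    let byte: u8 = 0x7F;
    let shift: u32 = 70; // i.e. a tenth continuation byte in a malformed varint
    // wrapping_shl masks the shift amount modulo 64, so this cannot panic;
    // the equivalent `<< 70` would panic in a debug build.
    let contribution = ((byte & 0x7F) as u64).wrapping_shl(shift);
    assert_eq!(contribution, 0x7F_u64 << (70 % 64));
}
```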
2 changes: 1 addition & 1 deletion parquet/tests/arrow_reader/bad_data.rs
@@ -106,7 +106,7 @@ fn test_arrow_rs_gh_6229_dict_header() {
let err = read_file("ARROW-RS-GH-6229-DICTHEADER.parquet").unwrap_err();
assert_eq!(
err.to_string(),
"External: Parquet argument error: EOF: eof decoding byte array"
"External: Parquet argument error: Parquet error: Integer overflow: out of range integral type conversion attempted"
);
}
