From 03eba911ab6225e4d27ec5c8bf77f1628879249f Mon Sep 17 00:00:00 2001 From: Ishan Jain Date: Sat, 30 Jul 2022 21:04:12 +0530 Subject: [PATCH] Traverse btree and read rows --- src/header.rs | 41 ++++++++++++----- src/main.rs | 107 ++++++++++++++++++++++++++++++++------------ src/record.rs | 121 +++++++++++++++++++++++++++++++++++++++++++++----- src/schema.rs | 22 ++++----- 4 files changed, 232 insertions(+), 59 deletions(-) diff --git a/src/header.rs b/src/header.rs index 9e9af62..172ab47 100644 --- a/src/header.rs +++ b/src/header.rs @@ -1,7 +1,7 @@ use anyhow::{bail, Result}; use std::convert::TryInto; -#[derive(Debug)] +#[derive(Debug, Eq, PartialEq)] pub enum BTreePage { InteriorIndex = 2, InteriorTable = 5, @@ -16,11 +16,12 @@ pub struct PageHeader { pub number_of_cells: u16, pub start_of_content_area: u16, pub fragmented_free_bytes: u8, + pub right_most_pointer: Option, } impl PageHeader { /// Parses a page header stream into a page header - pub fn parse(stream: &[u8]) -> Result { + pub fn parse(stream: &[u8]) -> Result<(usize, Self)> { let page_type = match stream[0] { 2 => BTreePage::InteriorIndex, 5 => BTreePage::InteriorTable, @@ -32,13 +33,33 @@ impl PageHeader { let number_of_cells = u16::from_be_bytes(stream[3..5].try_into()?); let start_of_content_area = u16::from_be_bytes(stream[5..7].try_into()?); let fragmented_free_bytes = stream[7]; - let header = PageHeader { - page_type, - first_free_block_start, - number_of_cells, - start_of_content_area, - fragmented_free_bytes, - }; - Ok(header) + + if page_type == BTreePage::InteriorTable { + Ok(( + 12, + PageHeader { + page_type, + first_free_block_start, + number_of_cells, + start_of_content_area, + fragmented_free_bytes, + right_most_pointer: Some(u32::from_be_bytes([ + stream[8], stream[9], stream[10], stream[11], + ])), + }, + )) + } else { + Ok(( + 8, + PageHeader { + page_type, + first_free_block_start, + number_of_cells, + start_of_content_area, + fragmented_free_bytes, + right_most_pointer: None, + }, + )) + } } } diff --git a/src/main.rs b/src/main.rs index 25116a3..72e54e5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,8 @@ use anyhow::{bail, Error, Result}; use once_cell::sync::Lazy; use regex::{Regex, RegexBuilder}; +use sqlite_starter_rust::header::BTreePage; +use sqlite_starter_rust::record::ColumnValue; use sqlite_starter_rust::{ header::PageHeader, record::parse_record, schema::Schema, varint::parse_varint, }; @@ -47,7 +49,7 @@ fn main() -> Result<()> { match command.as_str() { ".dbinfo" => { // Parse page header from database - let page_header = PageHeader::parse(&database[100..108])?; + let (_, page_header) = PageHeader::parse(&database[100..108])?; // Obtain all cell pointers let cell_pointers = database[108..] @@ -77,7 +79,7 @@ fn main() -> Result<()> { ".tables" => { // Parse page header from database - let page_header = PageHeader::parse(&database[100..108])?; + let (_, page_header) = PageHeader::parse(&database[100..108])?; // Obtain all cell pointers let cell_pointers = database[108..] @@ -119,53 +121,100 @@ fn main() -> Result<()> { } } +fn parse_page<'a>( + database: &'a [u8], + column_map: &'a HashMap<&str, usize>, + page_size: usize, + page_num: usize, +) -> Option>)> + 'a>> { + let table_page_offset = page_size * (page_num - 1); + let (read, page_header) = + PageHeader::parse(&database[table_page_offset..table_page_offset + 12]).unwrap(); + + let cell_pointers = database[table_page_offset + read..] + .chunks_exact(2) + .take(page_header.number_of_cells.into()) + .map(|bytes| u16::from_be_bytes(bytes.try_into().unwrap())); + + match page_header.page_type { + BTreePage::InteriorIndex => todo!(), + BTreePage::InteriorTable => { + let rows = cell_pointers + .into_iter() + .map(move |cp| { + let stream = &database[table_page_offset + cp as usize..]; + let left_child_id = + u32::from_be_bytes([stream[0], stream[1], stream[2], stream[3]]); + + let (_rowid, _offset) = parse_varint(&stream[4..]); + + parse_page(database, column_map, page_size, left_child_id as usize) + }) + .flatten() + .flatten(); + + Some(Box::new(rows)) + } + BTreePage::LeafIndex => todo!(), + BTreePage::LeafTable => { + let rows = cell_pointers.into_iter().map(move |cp| { + let stream = &database[table_page_offset + cp as usize..]; + let (total, offset) = parse_varint(stream); + let (rowid, read_bytes) = parse_varint(&stream[offset..]); + + ( + rowid, + parse_record( + &stream[offset + read_bytes..offset + read_bytes + total as usize], + column_map.len(), + ) + .unwrap(), + ) + }); + + Some(Box::new(rows)) + } + } +} + fn read_columns(query: &str, db_header: DBHeader, database: &[u8]) -> Result<(), Error> { let (columns, table, where_clause) = read_column_and_table(query); // Assume it's valid SQL - let schema = db_header .schemas - .into_iter() + .iter() .find(|schema| schema.table_name == table) .unwrap(); let column_map = find_column_positions(&schema.sql); - let table_page_offset = db_header.page_size as usize * (schema.root_page as usize - 1); - let page_header = - PageHeader::parse(&database[table_page_offset..table_page_offset + 8]).unwrap(); + let rows = parse_page( + database, + &column_map, + db_header.page_size as usize, + schema.root_page as usize, + ); - let cell_pointers = database[table_page_offset + 8..] - .chunks_exact(2) - .take(page_header.number_of_cells.into()) - .map(|bytes| u16::from_be_bytes(bytes.try_into().unwrap())); - - let rows = cell_pointers.into_iter().map(|cp| { - let stream = &database[table_page_offset + cp as usize..]; - let (_, offset) = parse_varint(stream); - let (_, read_bytes) = parse_varint(&stream[offset..]); - - parse_record(&stream[offset + read_bytes..], column_map.len()).unwrap() - }); - - for row in rows { + for (rowid, row) in rows.unwrap() { let mut output = String::new(); if let Some(wc) = where_clause { let colidx = *column_map.get(wc.0).unwrap(); - let row_pol = String::from_utf8_lossy(&row[colidx]); + let row_pol = row[colidx].read_string(); if row_pol != wc.1 { continue; } } - for column in columns.iter() { - let cpos = *column_map.get(column).unwrap(); - let s = String::from_utf8_lossy(&row[cpos]); - - output.push_str(&s); + for &column in columns.iter() { + if column == "id" { + output.push_str(&rowid.to_string()); + } else { + let cpos = *column_map.get(column).unwrap(); + output.push_str(&row[cpos].to_string()); + } output.push('|'); } @@ -186,7 +235,7 @@ struct DBHeader { fn read_db_header(database: &[u8]) -> Result { let db_page_size = u16::from_be_bytes([database[16], database[17]]); // Parse page header from database - let page_header = PageHeader::parse(&database[100..108])?; + let (_, page_header) = PageHeader::parse(&database[100..108])?; // Obtain all cell pointers let cell_pointers = database[108..] @@ -223,7 +272,7 @@ fn count_rows_in_table(query: &str, db_header: DBHeader, database: &[u8]) -> Res .unwrap(); let table_page_offset = db_header.page_size as usize * (schema.root_page as usize - 1); - let page_header = + let (_, page_header) = PageHeader::parse(&database[table_page_offset..table_page_offset + 8]).unwrap(); println!("{}", page_header.number_of_cells); diff --git a/src/record.rs b/src/record.rs index 55d051b..7096b1a 100644 --- a/src/record.rs +++ b/src/record.rs @@ -1,9 +1,11 @@ +use std::fmt::Display; + use crate::varint::parse_varint; use anyhow::{bail, Result}; /// Reads SQLite's "Record Format" as mentioned here: /// [record_format](https://www.sqlite.org/fileformat.html#record_format) -pub fn parse_record(stream: &[u8], column_count: usize) -> Result>> { +pub fn parse_record(stream: &[u8], column_count: usize) -> Result> { // Parse number of bytes in header, and use bytes_read as offset let (_, mut offset) = parse_varint(stream); @@ -19,28 +21,127 @@ pub fn parse_record(stream: &[u8], column_count: usize) -> Result>> let mut record = vec![]; for serial_type in serial_types { let column = parse_column_value(&stream[offset..], serial_type as usize)?; - offset += column.len(); + offset += column.length(); record.push(column); } Ok(record) } -fn parse_column_value(stream: &[u8], serial_type: usize) -> Result> { - let column_value = match serial_type { - 0 => vec![], +#[derive(Debug)] +pub enum ColumnValue<'a> { + Null, + U8(u8), + U16(u16), + U24(u32), + U32(u32), + U48(u64), + U64(u64), + FP64(f64), + False, + True, + Blob(&'a [u8]), + Text(String), +} + +impl<'a> ColumnValue<'a> { + pub fn length(&self) -> usize { + match self { + ColumnValue::Null => 0, + ColumnValue::U8(_) => 1, + ColumnValue::U16(_) => 2, + ColumnValue::U24(_) => 3, + ColumnValue::U32(_) => 4, + ColumnValue::U48(_) => 6, + ColumnValue::U64(_) => 8, + ColumnValue::FP64(_) => 8, + ColumnValue::False => 0, + ColumnValue::True => 0, + ColumnValue::Blob(v) => v.len(), + ColumnValue::Text(v) => v.len(), + } + } + + pub fn read_string(&self) -> String { + match self { + ColumnValue::Text(v) => v.clone(), + ColumnValue::Null => String::new(), + _ => unreachable!(), + } + } + + pub fn read_u8(&self) -> u8 { + if let ColumnValue::U8(v) = self { + *v + } else { + unreachable!() + } + } +} + +impl<'a> Display for ColumnValue<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ColumnValue::Null => f.write_str(""), + ColumnValue::U8(v) => f.write_str(&v.to_string()), + ColumnValue::U16(v) => f.write_str(&v.to_string()), + ColumnValue::U24(v) => f.write_str(&v.to_string()), + ColumnValue::U32(v) => f.write_str(&v.to_string()), + ColumnValue::U48(v) => f.write_str(&v.to_string()), + ColumnValue::U64(v) => f.write_str(&v.to_string()), + ColumnValue::FP64(v) => f.write_str(&v.to_string()), + ColumnValue::False => f.write_str("false"), + ColumnValue::True => f.write_str("true"), + ColumnValue::Blob(v) => f.write_fmt(format_args!("{:?}", v)), + ColumnValue::Text(v) => f.write_str(v), + } + } +} + +fn parse_column_value(stream: &[u8], serial_type: usize) -> Result { + Ok(match serial_type { + 0 => ColumnValue::Null, // 8 bit twos-complement integer - 1 => vec![stream[0]], + 1 => ColumnValue::U8(stream[0]), + 2 => { + let value = (!(stream[0] as u16) << 8) + !stream[1] as u16 + 1; + + ColumnValue::U16(value) + } + + 3 => { + let value = + (!(stream[0] as u32) << 16) + (!(stream[1] as u32) << 8) + !stream[2] as u32 + 1; + + ColumnValue::U24(value) + } + + 4 => { + let value = (!(stream[0] as u32) << 24) + + (!(stream[0] as u32) << 16) + + (!(stream[1] as u32) << 8) + + !stream[2] as u32 + + 1; + + ColumnValue::U32(value) + } + + 8 => ColumnValue::False, + 9 => ColumnValue::True, + // Text encoding n if serial_type >= 12 && serial_type % 2 == 0 => { let n_bytes = (n - 12) / 2; - stream[0..n_bytes as usize].to_vec() + ColumnValue::Blob(&stream[0..n_bytes as usize]) } n if serial_type >= 13 && serial_type % 2 == 1 => { let n_bytes = (n - 13) / 2; - stream[0..n_bytes as usize].to_vec() + let a = &stream[0..n_bytes as usize]; + + let s = String::from_utf8_lossy(a); + + ColumnValue::Text(s.to_string()) } _ => bail!("Invalid serial_type: {}", serial_type), - }; - Ok(column_value) + }) } diff --git a/src/schema.rs b/src/schema.rs index 81b638a..5467521 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -1,3 +1,5 @@ +use crate::record::ColumnValue; + #[derive(Debug)] pub struct Schema { pub kind: String, @@ -9,20 +11,20 @@ pub struct Schema { impl Schema { /// Parses a record into a schema - pub fn parse(record: Vec>) -> Option { + pub fn parse(record: Vec) -> Option { let mut items = record.into_iter(); - let kind = items.next()?; - let name = items.next()?; - let table_name = items.next()?; - let root_page = *items.next()?.get(0)?; - let sql = items.next()?; + let kind = items.next()?.read_string(); + let name = items.next()?.read_string(); + let table_name = items.next()?.read_string(); + let root_page = items.next()?.read_u8(); + let sql = items.next()?.read_string(); let schema = Self { - kind: String::from_utf8_lossy(&kind).to_string(), - name: String::from_utf8_lossy(&name).to_string(), - table_name: String::from_utf8_lossy(&table_name).to_string(), + kind, + name, + table_name, root_page, - sql: String::from_utf8_lossy(&sql).to_string(), + sql, }; Some(schema) }