diff --git a/src/header.rs b/src/header.rs index 172ab47..02f459d 100644 --- a/src/header.rs +++ b/src/header.rs @@ -34,8 +34,8 @@ impl PageHeader { let start_of_content_area = u16::from_be_bytes(stream[5..7].try_into()?); let fragmented_free_bytes = stream[7]; - if page_type == BTreePage::InteriorTable { - Ok(( + match page_type { + BTreePage::InteriorIndex | BTreePage::InteriorTable => Ok(( 12, PageHeader { page_type, @@ -47,9 +47,9 @@ impl PageHeader { stream[8], stream[9], stream[10], stream[11], ])), }, - )) - } else { - Ok(( + )), + + BTreePage::LeafIndex | BTreePage::LeafTable => Ok(( 8, PageHeader { page_type, @@ -59,7 +59,7 @@ impl PageHeader { fragmented_free_bytes, right_most_pointer: None, }, - )) + )), } } } diff --git a/src/main.rs b/src/main.rs index a84fa2b..493b59a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,7 @@ use sqlite_starter_rust::record::ColumnValue; use sqlite_starter_rust::{ header::PageHeader, record::parse_record, schema::Schema, varint::parse_varint, }; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::convert::TryInto; use std::fs::File; use std::io::prelude::*; @@ -46,7 +46,7 @@ fn main() -> Result<()> { // Parse command and act accordingly let command = &args[2]; - match command.as_str() { + match command.as_str().trim() { ".dbinfo" => { // Parse page header from database let (_, page_header) = PageHeader::parse(&database[100..108])?; @@ -58,7 +58,6 @@ fn main() -> Result<()> { .map(|bytes| u16::from_be_bytes(bytes.try_into().unwrap())); // Obtain all records from column 5 - #[allow(unused_variables)] let schemas = cell_pointers .into_iter() .map(|cell_pointer| { @@ -88,7 +87,6 @@ fn main() -> Result<()> { .map(|bytes| u16::from_be_bytes(bytes.try_into().unwrap())); // Obtain all records from column 5 - #[allow(unused_variables)] let schemas = cell_pointers .into_iter() .map(|cell_pointer| { @@ -110,6 +108,15 @@ fn main() -> Result<()> { Ok(()) } + v if v.contains("companies") => { + let db_header = read_db_header(&database)?; + + // Traverse the index + read_index(&database, v, &db_header); + + Ok(()) + } + v => { let db_header = read_db_header(&database)?; if v.to_lowercase().contains("count(*)") { @@ -121,6 +128,70 @@ fn main() -> Result<()> { } } +fn read_index(database: &[u8], query: &str, db_header: &DBHeader) { + let (columns, table, where_clause) = read_column_and_table(query); + + let schema = db_header + .schemas + .iter() + .find(|schema| schema.table_name == table) + .unwrap(); + + let column_map = find_column_positions(&schema.sql); + + // Assume it's valid SQL + let index_schema = db_header + .schemas + .iter() + .find(|schema| schema.name == "idx_companies_country") + .unwrap(); + + let rows = parse_page( + database, + &db_header, + &column_map, + db_header.page_size as usize * (index_schema.root_page as usize - 1), + ); + + let rowids: HashSet = rows + .unwrap() + .filter_map(|(rowid, row)| { + if row[0].to_string() == where_clause.unwrap().1 { + Some(rowid) + } else { + None + } + }) + .collect(); + + let rows = parse_page( + database, + &db_header, + &column_map, + db_header.page_size as usize * (schema.root_page as usize - 1), + ) + .unwrap() + .filter(|(rowid, _)| rowids.contains(rowid)); + + for (rowid, row) in rows { + let mut output = String::new(); + + for &column in columns.iter() { + if column == "id" { + output.push_str(&rowid.to_string()); + } else { + let cpos = *column_map.get(column).unwrap(); + output.push_str(&row[cpos].to_string()); + } + output.push('|'); + } + + let output = output.trim_end_matches(|c| c == '|'); + + println!("{}", output); + } +} + fn parse_page<'a>( database: &'a [u8], db_header: &'a DBHeader, @@ -136,7 +207,6 @@ fn parse_page<'a>( .map(|bytes| u16::from_be_bytes(bytes.try_into().unwrap())); match page_header.page_type { - BTreePage::InteriorIndex => todo!(), BTreePage::InteriorTable => { let rows = cell_pointers .into_iter() @@ -172,7 +242,6 @@ fn parse_page<'a>( Some(Box::new(rows)) } } - BTreePage::LeafIndex => todo!(), BTreePage::LeafTable => { let rows = cell_pointers.into_iter().map(move |cp| { let stream = &database[table_page_offset + cp as usize..]; @@ -189,11 +258,84 @@ fn parse_page<'a>( ) }); + Some(Box::new(rows)) + } + BTreePage::InteriorIndex => { + let rows = cell_pointers + .into_iter() + .filter_map(move |cp| { + let stream = &database[table_page_offset + cp as usize..]; + let left_child_id = + u32::from_be_bytes([stream[0], stream[1], stream[2], stream[3]]); + let (payload_size, offset) = parse_varint(&stream[4..]); + /* + * + * There is some payload here but it only contains the key so we are just going + * to ignore it + */ + let record = parse_record(&stream[offset + 4..offset + 4 + payload_size], 2); + let record = record.unwrap(); + + Some( + parse_page( + database, + db_header, + column_map, + db_header.page_size as usize * (left_child_id as usize - 1), + ) + .unwrap() + .chain(std::iter::once((record[1].read_usize(), record))), + ) + + // println!( + // "left child id = {} payload size = {} offset = {} column count = {} country = {}", + // left_child_id, + // payload_size, + // offset, + // column_map.len(),country + // ); + // + // TODO(ishan): Read number of bytes of payload. + // Read any over flow pages properly + //parse_record( + // &stream[offset + 4..offset + payload_size + 4], + // column_map.len(), + //) + //.unwrap(), + }) + .flatten(); + + if let Some(rp) = page_header.right_most_pointer { + Some(Box::new( + rows.chain( + parse_page( + database, + db_header, + column_map, + db_header.page_size as usize * (rp as usize - 1), + ) + .unwrap(), + ), + )) + } else { + Some(Box::new(rows)) + } + } + + BTreePage::LeafIndex => { + let rows = cell_pointers.into_iter().filter_map(move |cp| { + let stream = &database[table_page_offset + cp as usize..]; + let (payload_size, offset) = parse_varint(&stream); + let record = parse_record(&stream[offset..offset + payload_size], 2); + let record = record.unwrap(); + + Some((record[1].read_usize(), record)) + }); + Some(Box::new(rows)) } } } - fn read_columns(query: &str, db_header: DBHeader, database: &[u8]) -> Result<(), Error> { let (columns, table, where_clause) = read_column_and_table(query); // Assume it's valid SQL @@ -218,7 +360,7 @@ fn read_columns(query: &str, db_header: DBHeader, database: &[u8]) -> Result<(), if let Some(wc) = where_clause { let colidx = *column_map.get(wc.0).unwrap(); - let row_pol = row[colidx].read_string(); + let row_pol = row[colidx].to_string(); if row_pol != wc.1 { continue; diff --git a/src/record.rs b/src/record.rs index 7096b1a..fec11d3 100644 --- a/src/record.rs +++ b/src/record.rs @@ -1,13 +1,12 @@ -use std::fmt::Display; - use crate::varint::parse_varint; use anyhow::{bail, Result}; +use std::fmt::Display; /// Reads SQLite's "Record Format" as mentioned here: /// [record_format](https://www.sqlite.org/fileformat.html#record_format) pub fn parse_record(stream: &[u8], column_count: usize) -> Result> { // Parse number of bytes in header, and use bytes_read as offset - let (_, mut offset) = parse_varint(stream); + let (header_size, mut offset) = parse_varint(stream); // Read each varint into serial types and modify the offset let mut serial_types = vec![]; @@ -17,6 +16,7 @@ pub fn parse_record(stream: &[u8], column_count: usize) -> Result Result { Null, U8(u8), @@ -41,7 +41,7 @@ pub enum ColumnValue<'a> { False, True, Blob(&'a [u8]), - Text(String), + Text(&'a [u8]), } impl<'a> ColumnValue<'a> { @@ -62,14 +62,6 @@ impl<'a> ColumnValue<'a> { } } - pub fn read_string(&self) -> String { - match self { - ColumnValue::Text(v) => v.clone(), - ColumnValue::Null => String::new(), - _ => unreachable!(), - } - } - pub fn read_u8(&self) -> u8 { if let ColumnValue::U8(v) = self { *v @@ -77,6 +69,18 @@ impl<'a> ColumnValue<'a> { unreachable!() } } + + pub fn read_usize(&self) -> usize { + match self { + ColumnValue::U8(v) => *v as usize, + ColumnValue::U16(v) => *v as usize, + ColumnValue::U24(v) => *v as usize, + ColumnValue::U32(v) => *v as usize, + ColumnValue::U48(v) => *v as usize, + ColumnValue::U64(v) => *v as usize, + _ => unreachable!(), + } + } } impl<'a> Display for ColumnValue<'a> { @@ -93,7 +97,7 @@ impl<'a> Display for ColumnValue<'a> { ColumnValue::False => f.write_str("false"), ColumnValue::True => f.write_str("true"), ColumnValue::Blob(v) => f.write_fmt(format_args!("{:?}", v)), - ColumnValue::Text(v) => f.write_str(v), + ColumnValue::Text(v) => f.write_str(&String::from_utf8(v.to_vec()).unwrap()), } } } @@ -103,28 +107,13 @@ fn parse_column_value(stream: &[u8], serial_type: usize) -> Result 0 => ColumnValue::Null, // 8 bit twos-complement integer 1 => ColumnValue::U8(stream[0]), - 2 => { - let value = (!(stream[0] as u16) << 8) + !stream[1] as u16 + 1; + 2 => ColumnValue::U16(u16::from_be_bytes([stream[0], stream[1]])), - ColumnValue::U16(value) - } + 3 => ColumnValue::U24(u32::from_be_bytes([0, stream[0], stream[1], stream[2]])), - 3 => { - let value = - (!(stream[0] as u32) << 16) + (!(stream[1] as u32) << 8) + !stream[2] as u32 + 1; - - ColumnValue::U24(value) - } - - 4 => { - let value = (!(stream[0] as u32) << 24) - + (!(stream[0] as u32) << 16) - + (!(stream[1] as u32) << 8) - + !stream[2] as u32 - + 1; - - ColumnValue::U32(value) - } + 4 => ColumnValue::U32(u32::from_be_bytes([ + stream[0], stream[1], stream[2], stream[3], + ])), 8 => ColumnValue::False, 9 => ColumnValue::True, @@ -138,9 +127,7 @@ fn parse_column_value(stream: &[u8], serial_type: usize) -> Result let n_bytes = (n - 13) / 2; let a = &stream[0..n_bytes as usize]; - let s = String::from_utf8_lossy(a); - - ColumnValue::Text(s.to_string()) + ColumnValue::Text(a) } _ => bail!("Invalid serial_type: {}", serial_type), }) diff --git a/src/schema.rs b/src/schema.rs index 5467521..32c5503 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -13,11 +13,11 @@ impl Schema { /// Parses a record into a schema pub fn parse(record: Vec) -> Option { let mut items = record.into_iter(); - let kind = items.next()?.read_string(); - let name = items.next()?.read_string(); - let table_name = items.next()?.read_string(); + let kind = items.next()?.to_string(); + let name = items.next()?.to_string(); + let table_name = items.next()?.to_string(); let root_page = items.next()?.read_u8(); - let sql = items.next()?.read_string(); + let sql = items.next()?.to_string(); let schema = Self { kind,