1
0

Traverse btree and read rows

This commit is contained in:
Ishan Jain 2022-07-30 21:04:12 +05:30
parent b346c3e8e0
commit 03eba911ab
4 changed files with 232 additions and 59 deletions

View File

@ -1,7 +1,7 @@
use anyhow::{bail, Result}; use anyhow::{bail, Result};
use std::convert::TryInto; use std::convert::TryInto;
#[derive(Debug)] #[derive(Debug, Eq, PartialEq)]
pub enum BTreePage { pub enum BTreePage {
InteriorIndex = 2, InteriorIndex = 2,
InteriorTable = 5, InteriorTable = 5,
@ -16,11 +16,12 @@ pub struct PageHeader {
pub number_of_cells: u16, pub number_of_cells: u16,
pub start_of_content_area: u16, pub start_of_content_area: u16,
pub fragmented_free_bytes: u8, pub fragmented_free_bytes: u8,
pub right_most_pointer: Option<u32>,
} }
impl PageHeader { impl PageHeader {
/// Parses a page header stream into a page header /// Parses a page header stream into a page header
pub fn parse(stream: &[u8]) -> Result<Self> { pub fn parse(stream: &[u8]) -> Result<(usize, Self)> {
let page_type = match stream[0] { let page_type = match stream[0] {
2 => BTreePage::InteriorIndex, 2 => BTreePage::InteriorIndex,
5 => BTreePage::InteriorTable, 5 => BTreePage::InteriorTable,
@ -32,13 +33,33 @@ impl PageHeader {
let number_of_cells = u16::from_be_bytes(stream[3..5].try_into()?); let number_of_cells = u16::from_be_bytes(stream[3..5].try_into()?);
let start_of_content_area = u16::from_be_bytes(stream[5..7].try_into()?); let start_of_content_area = u16::from_be_bytes(stream[5..7].try_into()?);
let fragmented_free_bytes = stream[7]; let fragmented_free_bytes = stream[7];
let header = PageHeader {
page_type, if page_type == BTreePage::InteriorTable {
first_free_block_start, Ok((
number_of_cells, 12,
start_of_content_area, PageHeader {
fragmented_free_bytes, page_type,
}; first_free_block_start,
Ok(header) number_of_cells,
start_of_content_area,
fragmented_free_bytes,
right_most_pointer: Some(u32::from_be_bytes([
stream[8], stream[9], stream[10], stream[11],
])),
},
))
} else {
Ok((
8,
PageHeader {
page_type,
first_free_block_start,
number_of_cells,
start_of_content_area,
fragmented_free_bytes,
right_most_pointer: None,
},
))
}
} }
} }

View File

@ -1,6 +1,8 @@
use anyhow::{bail, Error, Result}; use anyhow::{bail, Error, Result};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use regex::{Regex, RegexBuilder}; use regex::{Regex, RegexBuilder};
use sqlite_starter_rust::header::BTreePage;
use sqlite_starter_rust::record::ColumnValue;
use sqlite_starter_rust::{ use sqlite_starter_rust::{
header::PageHeader, record::parse_record, schema::Schema, varint::parse_varint, header::PageHeader, record::parse_record, schema::Schema, varint::parse_varint,
}; };
@ -47,7 +49,7 @@ fn main() -> Result<()> {
match command.as_str() { match command.as_str() {
".dbinfo" => { ".dbinfo" => {
// Parse page header from database // Parse page header from database
let page_header = PageHeader::parse(&database[100..108])?; let (_, page_header) = PageHeader::parse(&database[100..108])?;
// Obtain all cell pointers // Obtain all cell pointers
let cell_pointers = database[108..] let cell_pointers = database[108..]
@ -77,7 +79,7 @@ fn main() -> Result<()> {
".tables" => { ".tables" => {
// Parse page header from database // Parse page header from database
let page_header = PageHeader::parse(&database[100..108])?; let (_, page_header) = PageHeader::parse(&database[100..108])?;
// Obtain all cell pointers // Obtain all cell pointers
let cell_pointers = database[108..] let cell_pointers = database[108..]
@ -119,53 +121,100 @@ fn main() -> Result<()> {
} }
} }
/// Walks the table b-tree rooted at page `page_num` and lazily yields a
/// `(rowid, columns)` pair for every row stored in its leaf pages.
///
/// * `database` - the whole database file.
/// * `column_map` - only its length is used; it is handed to `parse_record`
///   as the number of columns to decode.
/// * `page_size` - size of one page in bytes.
/// * `page_num` - 1-based page number of the page to read.
///
/// Interior table pages are expanded recursively: every cell's left child is
/// visited in cell order, followed by the page's right-most child, so no
/// subtree is skipped. The function returns `Some` for both table page kinds.
///
/// Cell payloads are read as contiguous bytes of `database`.
///
/// NOTE(review): the page header is read at the very start of the page, so this
/// presumably must not be called for page 1, whose header the rest of this file
/// reads at offset 100 — confirm.
///
/// # Panics
///
/// Panics on index pages (`todo!()`), when the page header or a record fails to
/// parse, and when an offset points outside `database`. `page_num == 0`
/// underflows the page offset computation.
fn parse_page<'a>(
    database: &'a [u8],
    column_map: &'a HashMap<&str, usize>,
    page_size: usize,
    page_num: usize,
) -> Option<Box<dyn Iterator<Item = (usize, Vec<ColumnValue<'a>>)> + 'a>> {
    let page_start = page_size * (page_num - 1);

    // Always offer 12 bytes: `PageHeader::parse` reads bytes 8..12 on interior
    // table pages and reports how many bytes the header really occupies.
    let (header_len, header) =
        PageHeader::parse(&database[page_start..page_start + 12]).unwrap();

    // The cell pointer array follows the header: one big-endian u16 per cell,
    // each an offset relative to the start of the page.
    let cell_offsets = database[page_start + header_len..]
        .chunks_exact(2)
        .take(header.number_of_cells.into())
        .map(|pair| usize::from(u16::from_be_bytes([pair[0], pair[1]])));

    match header.page_type {
        BTreePage::InteriorIndex => todo!(),
        BTreePage::InteriorTable => {
            // The last child is stored in the page header, not in a cell.
            let right_most_child = header.right_most_pointer.map(|page| page as usize);

            let rows = cell_offsets
                .map(move |cell| {
                    // An interior table cell starts with the 4-byte page number
                    // of its left child.
                    let cell_bytes = &database[page_start + cell..];
                    u32::from_be_bytes([
                        cell_bytes[0],
                        cell_bytes[1],
                        cell_bytes[2],
                        cell_bytes[3],
                    ]) as usize
                })
                .chain(right_most_child)
                .filter_map(move |child| parse_page(database, column_map, page_size, child))
                .flatten();

            Some(Box::new(rows))
        }
        BTreePage::LeafIndex => todo!(),
        BTreePage::LeafTable => {
            let column_count = column_map.len();

            let rows = cell_offsets.map(move |cell| {
                // Leaf table cell layout: payload size varint, rowid varint, payload.
                let cell_bytes = &database[page_start + cell..];
                let (payload_len, len_size) = parse_varint(cell_bytes);
                let (rowid, rowid_size) = parse_varint(&cell_bytes[len_size..]);

                let payload_start = len_size + rowid_size;
                let payload = &cell_bytes[payload_start..payload_start + payload_len as usize];

                (rowid, parse_record(payload, column_count).unwrap())
            });

            Some(Box::new(rows))
        }
    }
}
fn read_columns(query: &str, db_header: DBHeader, database: &[u8]) -> Result<(), Error> { fn read_columns(query: &str, db_header: DBHeader, database: &[u8]) -> Result<(), Error> {
let (columns, table, where_clause) = read_column_and_table(query); let (columns, table, where_clause) = read_column_and_table(query);
// Assume it's valid SQL // Assume it's valid SQL
let schema = db_header let schema = db_header
.schemas .schemas
.into_iter() .iter()
.find(|schema| schema.table_name == table) .find(|schema| schema.table_name == table)
.unwrap(); .unwrap();
let column_map = find_column_positions(&schema.sql); let column_map = find_column_positions(&schema.sql);
let table_page_offset = db_header.page_size as usize * (schema.root_page as usize - 1); let rows = parse_page(
let page_header = database,
PageHeader::parse(&database[table_page_offset..table_page_offset + 8]).unwrap(); &column_map,
db_header.page_size as usize,
schema.root_page as usize,
);
let cell_pointers = database[table_page_offset + 8..] for (rowid, row) in rows.unwrap() {
.chunks_exact(2)
.take(page_header.number_of_cells.into())
.map(|bytes| u16::from_be_bytes(bytes.try_into().unwrap()));
let rows = cell_pointers.into_iter().map(|cp| {
let stream = &database[table_page_offset + cp as usize..];
let (_, offset) = parse_varint(stream);
let (_, read_bytes) = parse_varint(&stream[offset..]);
parse_record(&stream[offset + read_bytes..], column_map.len()).unwrap()
});
for row in rows {
let mut output = String::new(); let mut output = String::new();
if let Some(wc) = where_clause { if let Some(wc) = where_clause {
let colidx = *column_map.get(wc.0).unwrap(); let colidx = *column_map.get(wc.0).unwrap();
let row_pol = String::from_utf8_lossy(&row[colidx]); let row_pol = row[colidx].read_string();
if row_pol != wc.1 { if row_pol != wc.1 {
continue; continue;
} }
} }
for column in columns.iter() { for &column in columns.iter() {
let cpos = *column_map.get(column).unwrap(); if column == "id" {
let s = String::from_utf8_lossy(&row[cpos]); output.push_str(&rowid.to_string());
} else {
output.push_str(&s); let cpos = *column_map.get(column).unwrap();
output.push_str(&row[cpos].to_string());
}
output.push('|'); output.push('|');
} }
@ -186,7 +235,7 @@ struct DBHeader {
fn read_db_header(database: &[u8]) -> Result<DBHeader, Error> { fn read_db_header(database: &[u8]) -> Result<DBHeader, Error> {
let db_page_size = u16::from_be_bytes([database[16], database[17]]); let db_page_size = u16::from_be_bytes([database[16], database[17]]);
// Parse page header from database // Parse page header from database
let page_header = PageHeader::parse(&database[100..108])?; let (_, page_header) = PageHeader::parse(&database[100..108])?;
// Obtain all cell pointers // Obtain all cell pointers
let cell_pointers = database[108..] let cell_pointers = database[108..]
@ -223,7 +272,7 @@ fn count_rows_in_table(query: &str, db_header: DBHeader, database: &[u8]) -> Res
.unwrap(); .unwrap();
let table_page_offset = db_header.page_size as usize * (schema.root_page as usize - 1); let table_page_offset = db_header.page_size as usize * (schema.root_page as usize - 1);
let page_header = let (_, page_header) =
PageHeader::parse(&database[table_page_offset..table_page_offset + 8]).unwrap(); PageHeader::parse(&database[table_page_offset..table_page_offset + 8]).unwrap();
println!("{}", page_header.number_of_cells); println!("{}", page_header.number_of_cells);

View File

@ -1,9 +1,11 @@
use std::fmt::Display;
use crate::varint::parse_varint; use crate::varint::parse_varint;
use anyhow::{bail, Result}; use anyhow::{bail, Result};
/// Reads SQLite's "Record Format" as mentioned here: /// Reads SQLite's "Record Format" as mentioned here:
/// [record_format](https://www.sqlite.org/fileformat.html#record_format) /// [record_format](https://www.sqlite.org/fileformat.html#record_format)
pub fn parse_record(stream: &[u8], column_count: usize) -> Result<Vec<Vec<u8>>> { pub fn parse_record(stream: &[u8], column_count: usize) -> Result<Vec<ColumnValue>> {
// Parse number of bytes in header, and use bytes_read as offset // Parse number of bytes in header, and use bytes_read as offset
let (_, mut offset) = parse_varint(stream); let (_, mut offset) = parse_varint(stream);
@ -19,28 +21,127 @@ pub fn parse_record(stream: &[u8], column_count: usize) -> Result<Vec<Vec<u8>>>
let mut record = vec![]; let mut record = vec![];
for serial_type in serial_types { for serial_type in serial_types {
let column = parse_column_value(&stream[offset..], serial_type as usize)?; let column = parse_column_value(&stream[offset..], serial_type as usize)?;
offset += column.len(); offset += column.length();
record.push(column); record.push(column);
} }
Ok(record) Ok(record)
} }
fn parse_column_value(stream: &[u8], serial_type: usize) -> Result<Vec<u8>> { #[derive(Debug)]
let column_value = match serial_type { pub enum ColumnValue<'a> {
0 => vec![], Null,
U8(u8),
U16(u16),
U24(u32),
U32(u32),
U48(u64),
U64(u64),
FP64(f64),
False,
True,
Blob(&'a [u8]),
Text(String),
}
impl<'a> ColumnValue<'a> {
pub fn length(&self) -> usize {
match self {
ColumnValue::Null => 0,
ColumnValue::U8(_) => 1,
ColumnValue::U16(_) => 2,
ColumnValue::U24(_) => 3,
ColumnValue::U32(_) => 4,
ColumnValue::U48(_) => 6,
ColumnValue::U64(_) => 8,
ColumnValue::FP64(_) => 8,
ColumnValue::False => 0,
ColumnValue::True => 0,
ColumnValue::Blob(v) => v.len(),
ColumnValue::Text(v) => v.len(),
}
}
pub fn read_string(&self) -> String {
match self {
ColumnValue::Text(v) => v.clone(),
ColumnValue::Null => String::new(),
_ => unreachable!(),
}
}
pub fn read_u8(&self) -> u8 {
if let ColumnValue::U8(v) = self {
*v
} else {
unreachable!()
}
}
}
impl<'a> Display for ColumnValue<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ColumnValue::Null => f.write_str(""),
ColumnValue::U8(v) => f.write_str(&v.to_string()),
ColumnValue::U16(v) => f.write_str(&v.to_string()),
ColumnValue::U24(v) => f.write_str(&v.to_string()),
ColumnValue::U32(v) => f.write_str(&v.to_string()),
ColumnValue::U48(v) => f.write_str(&v.to_string()),
ColumnValue::U64(v) => f.write_str(&v.to_string()),
ColumnValue::FP64(v) => f.write_str(&v.to_string()),
ColumnValue::False => f.write_str("false"),
ColumnValue::True => f.write_str("true"),
ColumnValue::Blob(v) => f.write_fmt(format_args!("{:?}", v)),
ColumnValue::Text(v) => f.write_str(v),
}
}
}
fn parse_column_value(stream: &[u8], serial_type: usize) -> Result<ColumnValue> {
Ok(match serial_type {
0 => ColumnValue::Null,
// 8 bit twos-complement integer // 8 bit twos-complement integer
1 => vec![stream[0]], 1 => ColumnValue::U8(stream[0]),
2 => {
let value = (!(stream[0] as u16) << 8) + !stream[1] as u16 + 1;
ColumnValue::U16(value)
}
3 => {
let value =
(!(stream[0] as u32) << 16) + (!(stream[1] as u32) << 8) + !stream[2] as u32 + 1;
ColumnValue::U24(value)
}
4 => {
let value = (!(stream[0] as u32) << 24)
+ (!(stream[0] as u32) << 16)
+ (!(stream[1] as u32) << 8)
+ !stream[2] as u32
+ 1;
ColumnValue::U32(value)
}
8 => ColumnValue::False,
9 => ColumnValue::True,
// Text encoding // Text encoding
n if serial_type >= 12 && serial_type % 2 == 0 => { n if serial_type >= 12 && serial_type % 2 == 0 => {
let n_bytes = (n - 12) / 2; let n_bytes = (n - 12) / 2;
stream[0..n_bytes as usize].to_vec() ColumnValue::Blob(&stream[0..n_bytes as usize])
} }
n if serial_type >= 13 && serial_type % 2 == 1 => { n if serial_type >= 13 && serial_type % 2 == 1 => {
let n_bytes = (n - 13) / 2; let n_bytes = (n - 13) / 2;
stream[0..n_bytes as usize].to_vec() let a = &stream[0..n_bytes as usize];
let s = String::from_utf8_lossy(a);
ColumnValue::Text(s.to_string())
} }
_ => bail!("Invalid serial_type: {}", serial_type), _ => bail!("Invalid serial_type: {}", serial_type),
}; })
Ok(column_value)
} }

View File

@ -1,3 +1,5 @@
use crate::record::ColumnValue;
#[derive(Debug)] #[derive(Debug)]
pub struct Schema { pub struct Schema {
pub kind: String, pub kind: String,
@ -9,20 +11,20 @@ pub struct Schema {
impl Schema { impl Schema {
/// Parses a record into a schema /// Parses a record into a schema
pub fn parse(record: Vec<Vec<u8>>) -> Option<Self> { pub fn parse(record: Vec<ColumnValue>) -> Option<Self> {
let mut items = record.into_iter(); let mut items = record.into_iter();
let kind = items.next()?; let kind = items.next()?.read_string();
let name = items.next()?; let name = items.next()?.read_string();
let table_name = items.next()?; let table_name = items.next()?.read_string();
let root_page = *items.next()?.get(0)?; let root_page = items.next()?.read_u8();
let sql = items.next()?; let sql = items.next()?.read_string();
let schema = Self { let schema = Self {
kind: String::from_utf8_lossy(&kind).to_string(), kind,
name: String::from_utf8_lossy(&name).to_string(), name,
table_name: String::from_utf8_lossy(&table_name).to_string(), table_name,
root_page, root_page,
sql: String::from_utf8_lossy(&sql).to_string(), sql,
}; };
Some(schema) Some(schema)
} }