diff --git a/src/main.rs b/src/main.rs index 8f38d0c..25116a3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,6 +18,16 @@ const QUERY_REGEX: Lazy = Lazy::new(|| { .expect("error in compiling regex") }); +const WHERE_REGEX: Lazy = Lazy::new(|| { + let regex = + "select ([a-zA-Z0-9*].*) FROM ([a-zA-Z0-9].*) WHERE ([a-zA-Z0-9].*) = ([a-zA-Z0-9'].*)"; + + RegexBuilder::new(regex) + .case_insensitive(true) + .build() + .expect("error in compiling regex") +}); + fn main() -> Result<()> { // Parse arguments let args = std::env::args().collect::>(); @@ -110,7 +120,7 @@ fn main() -> Result<()> { } fn read_columns(query: &str, db_header: DBHeader, database: &[u8]) -> Result<(), Error> { - let (columns, table) = read_column_and_table(query); + let (columns, table, where_clause) = read_column_and_table(query); // Assume it's valid SQL let schema = db_header @@ -141,6 +151,16 @@ fn read_columns(query: &str, db_header: DBHeader, database: &[u8]) -> Result<(), for row in rows { let mut output = String::new(); + if let Some(wc) = where_clause { + let colidx = *column_map.get(wc.0).unwrap(); + + let row_pol = String::from_utf8_lossy(&row[colidx]); + + if row_pol != wc.1 { + continue; + } + } + for column in columns.iter() { let cpos = *column_map.get(column).unwrap(); let s = String::from_utf8_lossy(&row[cpos]); @@ -193,7 +213,7 @@ fn read_db_header(database: &[u8]) -> Result { } fn count_rows_in_table(query: &str, db_header: DBHeader, database: &[u8]) -> Result<(), Error> { - let (_, table) = read_column_and_table(query); + let (_, table, _) = read_column_and_table(query); // Assume it's valid SQL let schema = db_header @@ -226,11 +246,31 @@ fn find_column_positions(schema: &str) -> HashMap<&str, usize> { .collect() } -fn read_column_and_table(query: &str) -> (Vec<&str>, &str) { +fn read_column_and_table(query: &str) -> (Vec<&str>, &str, Option<(&str, &str)>) { + if let Some(matches) = WHERE_REGEX.captures(query) { + let parameter = matches.get(3).unwrap().as_str().trim(); + let value = matches.get(4).unwrap().as_str().trim(); + let columns = matches.get(1).unwrap().as_str(); + let table = matches.get(2).unwrap().as_str(); + let table: &str = table.trim_matches(|c: char| !c.is_alphabetic()); + let column = columns + .split(',') + .filter(|c| !c.is_empty()) + .map(|c| c.trim()) + .collect(); + + return ( + column, + table, + Some((parameter, value.trim_matches(|c| c == '\''))), + ); + } + let matches = QUERY_REGEX.captures(query).unwrap(); let columns = matches.get(1).unwrap().as_str(); let table = matches.get(2).unwrap().as_str(); + let table: &str = table.trim_matches(|c: char| !c.is_alphabetic()); let column = columns .split(',') @@ -238,5 +278,5 @@ fn read_column_and_table(query: &str) -> (Vec<&str>, &str) { .map(|c| c.trim()) .collect(); - (column, table) + (column, table, None) }