From 1ad59279d59d1ebd65359c1b20b4d149421ae110 Mon Sep 17 00:00:00 2001 From: Ishan Jain Date: Mon, 11 Dec 2023 01:38:01 +0530 Subject: [PATCH] Refactored project a bit, added question parser --- src/header.rs | 64 ++++++++++++++++ src/main.rs | 195 +++++++++--------------------------------------- src/parser.rs | 0 src/qname.rs | 61 +++++++++++++++ src/question.rs | 37 +++++++++ src/rrecord.rs | 76 +++++++++++++++++++ 6 files changed, 273 insertions(+), 160 deletions(-) create mode 100644 src/header.rs create mode 100644 src/parser.rs create mode 100644 src/qname.rs create mode 100644 src/question.rs create mode 100644 src/rrecord.rs diff --git a/src/header.rs b/src/header.rs new file mode 100644 index 0000000..4a3faf7 --- /dev/null +++ b/src/header.rs @@ -0,0 +1,64 @@ +#[derive(Debug)] +pub struct Header { + pub ident: u16, + pub query: bool, + pub opcode: u8, // TODO: enum + pub authoritative: bool, + pub truncated: bool, + pub recursion_desired: bool, + pub recursion_avail: bool, + pub reserved: u8, + pub rcode: u8, // TODO: enum + pub qd_count: u16, + pub an_count: u16, + pub authority_records: u16, + pub additional_records: u16, +} + +impl Header { + pub fn parse(data: &[u8]) -> Result { + if data.len() != 12 { + return Err("input bytes len is not equal to 12"); + } + + Ok(Self { + ident: u16::from_be_bytes([data[0], data[1]]), + query: ((data[2] >> 7) & 1) == 1, + opcode: (data[2] >> 3), + authoritative: ((data[2] >> 2) & 1) == 1, + truncated: ((data[2] >> 1) & 1) == 1, + recursion_desired: (data[2] & 1) == 1, + recursion_avail: ((data[3] >> 7) & 1) == 1, + reserved: ((data[3] >> 4) & 0b111), + rcode: (data[3] & 0b1111), + qd_count: u16::from_be_bytes([data[4], data[5]]), + an_count: u16::from_be_bytes([data[6], data[7]]), + authority_records: u16::from_be_bytes([data[8], data[9]]), + additional_records: u16::from_be_bytes([data[10], data[11]]), + }) + } + + pub fn write_to(self, buf: &mut Vec) { + buf.reserve(12); + + // write ident + buf.extend(self.ident.to_be_bytes()); + + // Write flags + let flag0_byte = (self.query as u8) << 7 + | self.opcode << 3 + | (self.authoritative as u8) << 2 + | (self.truncated as u8) << 1 + | self.recursion_desired as u8; + let flag1_byte = (self.recursion_avail as u8) << 7 | self.reserved << 4 | self.rcode; + + buf.push(flag0_byte); + buf.push(flag1_byte); + + // Write counts + buf.extend(self.qd_count.to_be_bytes()); + buf.extend(self.an_count.to_be_bytes()); + buf.extend(self.authority_records.to_be_bytes()); + buf.extend(self.additional_records.to_be_bytes()); + } +} diff --git a/src/main.rs b/src/main.rs index 0aa2134..d450916 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,11 @@ +mod header; +mod qname; +use header::*; +mod question; +use question::*; +mod rrecord; +use rrecord::*; + use std::net::UdpSocket; fn main() { @@ -13,6 +21,8 @@ fn main() { let mut recv_packet = Packet::parse(received_data).unwrap(); + println!("{:?}", recv_packet); + recv_packet.header.query = true; recv_packet.header.authoritative = false; recv_packet.header.truncated = false; @@ -24,19 +34,13 @@ fn main() { recv_packet.header.authority_records = 0; recv_packet.header.additional_records = 0; - recv_packet.questions = vec![Question { - name: "codecrafters.io".into(), - q_type: 1, - class: 1, - }]; - recv_packet.answers = vec![RRecord { - name: "codecrafters.io".into(), + name: recv_packet.questions[0].name.clone(), class: 1, r_type: 1, ttl: 1337, rdlength: 4, - data: RData::A([0x8, 0x8, 0x8, 0x8]), + rdata: RData::A([0x8, 0x8, 0x8, 0x8]), }]; recv_packet.write_to(&mut response); @@ -53,165 +57,36 @@ fn main() { } } -struct Header { - ident: u16, - query: bool, - opcode: u8, // TODO: enum - authoritative: bool, - truncated: bool, - recursion_desired: bool, - recursion_avail: bool, - reserved: u8, - rcode: u8, // TODO: enum - qd_count: u16, - an_count: u16, - authority_records: u16, - additional_records: u16, -} - -impl Header { - pub fn parse(data: &[u8]) -> Result { - if data.len() != 12 { - return Err("input bytes len is not equal to 12"); - } - - Ok(Self { - ident: u16::from_be_bytes([data[0], data[1]]), - query: ((data[2] >> 7) & 1) == 1, - opcode: (data[2] >> 3), - authoritative: ((data[2] >> 2) & 1) == 1, - truncated: ((data[2] >> 1) & 1) == 1, - recursion_desired: (data[2] & 1) == 1, - recursion_avail: ((data[3] >> 7) & 1) == 1, - reserved: ((data[3] >> 4) & 0b111), - rcode: (data[3] & 0b1111), - qd_count: u16::from_be_bytes([data[4], data[5]]), - an_count: u16::from_be_bytes([data[6], data[7]]), - authority_records: u16::from_be_bytes([data[8], data[9]]), - additional_records: u16::from_be_bytes([data[10], data[11]]), - }) - } - - pub fn write_to(self, buf: &mut Vec) { - // write ident - buf.extend(self.ident.to_be_bytes()); - - // Write flags - let flag0_byte = (self.query as u8) << 7 - | self.opcode << 3 - | (self.authoritative as u8) << 2 - | (self.truncated as u8) << 1 - | self.recursion_desired as u8; - let flag1_byte = (self.recursion_avail as u8) << 7 | self.reserved << 4 | self.rcode; - - buf.push(flag0_byte); - buf.push(flag1_byte); - - // Write counts - buf.extend(self.qd_count.to_be_bytes()); - buf.extend(self.an_count.to_be_bytes()); - buf.extend(self.authority_records.to_be_bytes()); - buf.extend(self.additional_records.to_be_bytes()); - } -} -struct Question { - name: Qname, - q_type: u16, - class: u16, -} - -impl Question { - pub fn parse(data: &[u8]) -> Result { - Err("invalid") - } - - pub fn write_to(self, buf: &mut Vec) { - self.name.write_to(buf); - buf.extend(self.q_type.to_be_bytes()); - buf.extend(self.class.to_be_bytes()); - } -} - -struct Qname(Vec<(u8, String)>); - -impl Qname { - pub fn write_to(&self, buf: &mut Vec) { - for (i, v) in &self.0 { - buf.push(*i); - buf.extend(v.bytes()); - } - - buf.push(0); - } -} - -impl From<&str> for Qname { - fn from(value: &str) -> Self { - let mut output = vec![]; - - for label in value.split('.') { - output.push((label.len() as u8, label.to_string())); - } - - Qname(output) - } -} - -struct RRecord { - name: Qname, - r_type: u16, - class: u16, - ttl: u32, - rdlength: u16, - data: RData, -} - -impl RRecord { - pub fn parse(data: &[u8]) -> Result { - todo!() - } - - pub fn write_to(self, buf: &mut Vec) { - self.name.write_to(buf); - buf.extend(self.r_type.to_be_bytes()); - buf.extend(self.class.to_be_bytes()); - buf.extend(self.ttl.to_be_bytes()); - buf.extend(self.rdlength.to_be_bytes()); - - self.data.write_to(buf); - } -} - -enum RData { - A([u8; 4]), - Aaaa([u8; 16]), -} - -impl RData { - pub fn write_to(self, buf: &mut Vec) { - match self { - RData::A(addr) => buf.extend(addr), - RData::Aaaa(addr) => buf.extend(addr), - } - } -} - -struct Packet { +#[derive(Debug)] +struct Packet<'a> { header: Header, - questions: Vec, - answers: Vec, + questions: Vec>, + answers: Vec>, } -impl Packet { - pub fn parse(data: &[u8]) -> Result { +impl<'a> Packet<'a> { + pub fn parse(mut data: &'a [u8]) -> Result { let header = Header::parse(&data[..12])?; - // let questions = vec![Question::parse(&data[12..])?]; // TODO: need some thing better here - // let answers = vec![RRecord::parse(&data[12..])?]; // TODO: need some thing better here + data = &data[12..]; + + let mut questions = Vec::with_capacity(header.qd_count as usize); + + for _ in 0..header.qd_count { + let rec = Question::parse(data)?; + data = &data[rec.length()..]; + questions.push(rec); + } + let mut answers = Vec::with_capacity(header.an_count as usize); + for _ in 0..header.an_count { + let rec = RRecord::parse(data)?; + data = &data[rec.length()..]; + answers.push(rec); + } Ok(Self { header, - questions: vec![], - answers: vec![], + questions, + answers, }) } diff --git a/src/parser.rs b/src/parser.rs new file mode 100644 index 0000000..e69de29 diff --git a/src/qname.rs b/src/qname.rs new file mode 100644 index 0000000..caf9355 --- /dev/null +++ b/src/qname.rs @@ -0,0 +1,61 @@ +use std::{ + borrow::Cow, + fmt::{Display, Formatter, Result as FmtResult}, +}; + +#[derive(Debug, Clone)] +pub struct Qname<'a> { + labels: Cow<'a, [u8]>, +} + +impl<'a> Qname<'a> { + pub fn write_to(&self, buf: &mut Vec) { + buf.extend(self.labels.iter()); + } + + // Assume data _starts_ at the question section + pub fn parse(data: &'a [u8]) -> Self { + // Do a quick check to make sure every thing is valid + // and save the slice + + let mut i = 0; + while data[i] != 0 { + let length = data[i] as usize; + i += length + 1; + } + // NULL byte + i += 1; + + Self { + labels: Cow::Borrowed(&data[0..i]), + } + } + + pub fn length(&self) -> usize { + self.labels.len() + } +} + +impl<'a> From<&'a str> for Qname<'a> { + fn from(value: &'a str) -> Self { + let mut labels = vec![]; + + for label in value.split('.') { + labels.push(label.len() as u8); + labels.extend(label.bytes()); + } + + // for null byte + labels.push(0); + + Qname { + labels: Cow::Owned(labels), + } + } +} + +impl<'a> Display for Qname<'a> { + fn fmt(&self, f: &mut Formatter) -> FmtResult { + todo!() + } +} diff --git a/src/question.rs b/src/question.rs new file mode 100644 index 0000000..31db36b --- /dev/null +++ b/src/question.rs @@ -0,0 +1,37 @@ +use crate::qname::Qname; + +#[derive(Debug)] +pub struct Question<'a> { + pub name: Qname<'a>, + pub q_type: u16, + pub class: u16, +} + +impl<'a> Question<'a> { + pub fn parse(mut data: &'a [u8]) -> Result { + let qname = Qname::parse(data); + data = &data[qname.length()..]; + + debug_assert!(data.len() >= 4); + let q_type = u16::from_be_bytes([data[0], data[1]]); + let class = u16::from_be_bytes([data[2], data[3]]); + + Ok(Self { + name: qname, + q_type, + class, + }) + } + + pub fn length(&self) -> usize { + self.name.length() + 2 + 2 + } + + pub fn write_to(self, buf: &mut Vec) { + buf.reserve(self.length()); + + self.name.write_to(buf); + buf.extend(self.q_type.to_be_bytes()); + buf.extend(self.class.to_be_bytes()); + } +} diff --git a/src/rrecord.rs b/src/rrecord.rs new file mode 100644 index 0000000..56fcb8b --- /dev/null +++ b/src/rrecord.rs @@ -0,0 +1,76 @@ +use crate::qname::Qname; + +#[derive(Debug)] +pub struct RRecord<'a> { + pub name: Qname<'a>, + pub r_type: u16, + pub class: u16, + pub ttl: u32, + pub rdlength: u16, + pub rdata: RData, +} + +impl<'a> RRecord<'a> { + pub fn parse(mut data: &'a [u8]) -> Result { + let qname = Qname::parse(data); + data = &data[qname.length()..]; + debug_assert!(data.len() >= 10); + + let r_type = u16::from_be_bytes([data[0], data[1]]); + let class = u16::from_be_bytes([data[2], data[3]]); + let ttl = u32::from_be_bytes([data[4], data[5], data[6], data[7]]); + let rdlength = u16::from_be_bytes([data[8], data[9]]); + + data = &data[2 + 2 + 4 + 2..]; + + let rdata = Self::parse_rdata(r_type, data); + + Ok(Self { + name: qname, + r_type, + class, + ttl, + rdlength, + rdata, + }) + } + + fn parse_rdata(r_type: u16, data: &[u8]) -> RData { + match r_type { + 1 => RData::A([data[0], data[1], data[2], data[3]]), + + _ => unimplemented!(), + } + } + + pub fn write_to(self, buf: &mut Vec) { + buf.reserve(self.length()); + + self.name.write_to(buf); + buf.extend(self.r_type.to_be_bytes()); + buf.extend(self.class.to_be_bytes()); + buf.extend(self.ttl.to_be_bytes()); + buf.extend(self.rdlength.to_be_bytes()); + + self.rdata.write_to(buf); + } + + pub fn length(&self) -> usize { + self.name.length() + 2 + 2 + 4 + 2 + self.rdlength as usize + } +} + +#[derive(Debug)] +pub enum RData { + A([u8; 4]), + Aaaa([u8; 16]), +} + +impl RData { + pub fn write_to(self, buf: &mut Vec) { + match self { + RData::A(addr) => buf.extend(addr), + RData::Aaaa(addr) => buf.extend(addr), + } + } +}