diff --git a/src/routes/mod.rs b/src/routes/mod.rs index d688a9c..32cf85e 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -14,7 +14,7 @@ use warp::http::header::{ use warp::http::{Response, StatusCode}; use warp::{path, Filter, Rejection, Reply}; -pub fn routes(pool: &'static PgPool) -> impl Filter { +pub fn routes(pool: &'static PgPool) -> impl Filter { let pool = warp::any().map(move || pool); let index = warp::path::end() .and(warp::get2()) @@ -100,23 +100,58 @@ fn not_found(rejection: Rejection) -> Result { mod test { use super::routes; use crate::PgPool; - use diesel::r2d2::{ConnectionManager, Pool}; + use diesel::r2d2::{ConnectionManager, CustomizeConnection, Pool}; + use diesel::Connection; use lazy_static::lazy_static; use scraper::{Html, Selector}; use serde::Deserialize; use std::env; use std::str; + use warp::filters::BoxedFilter; + use warp::http::header::{CONTENT_LENGTH, LOCATION}; + use warp::reply::{Reply, Response}; + use warp::Filter; lazy_static! { static ref POOL: PgPool = { - let pool = Pool::new(ConnectionManager::new( - env::var("DATABASE_URL") - .expect("Setting DATABASE_URL environment variable required to run tests"), - )) - .expect("Couldn't create a connection pool"); + let pool = Pool::builder() + .connection_customizer(Box::new(ExecuteWithinTransaction)) + .max_size(1) + .build(ConnectionManager::new(env::var("DATABASE_URL").expect( + "Setting DATABASE_URL environment variable required to run tests", + ))) + .expect("Couldn't create a connection pool"); diesel_migrations::run_pending_migrations(&pool.get().unwrap()).unwrap(); pool }; + static ref ROUTES: BoxedFilter<(Response,)> = + routes(&POOL).map(Reply::into_response).boxed(); + } + + #[derive(Debug)] + struct ExecuteWithinTransaction; + + impl CustomizeConnection for ExecuteWithinTransaction + where + C: Connection, + { + fn on_acquire(&self, conn: &mut C) -> Result<(), E> { + conn.begin_test_transaction().unwrap(); + Ok(()) + } + } + + fn get_html_id() -> String { + let response = warp::test::request().reply(&*ROUTES); + let document = Html::parse_document(str::from_utf8(response.body()).unwrap()); + document + .select(&Selector::parse("#language option").unwrap()) + .find(|element| element.text().next() == Some("HTML")) + .expect("a language called HTML to exist") + .value() + .attr("value") + .expect("an ID") + .to_string() } #[test] @@ -127,19 +162,9 @@ mod test { mode: &'a str, mime: &'a str, } - let routes = routes(&POOL); - let response = warp::test::request().reply(&routes); - let document = Html::parse_document(str::from_utf8(response.body()).unwrap()); - let id = document - .select(&Selector::parse("#language option").unwrap()) - .find(|element| element.text().next() == Some("HTML")) - .expect("a language called HTML to exist") - .value() - .attr("value") - .expect("an ID"); let response = warp::test::request() - .path(&format!("/api/v0/language/{}", id)) - .reply(&routes); + .path(&format!("/api/v0/language/{}", get_html_id())) + .reply(&*ROUTES); assert_eq!( serde_json::from_slice::(response.body()).unwrap(), ApiLanguage { @@ -148,4 +173,22 @@ mod test { }, ); } + + #[test] + fn test_raw_pastes() { + let body = format!("language={}&code=abc", get_html_id()); + let reply = warp::test::request() + .method("POST") + .header(CONTENT_LENGTH, body.len()) + .body(body) + .reply(&*ROUTES); + let location = reply.headers()[LOCATION].to_str().unwrap(); + assert_eq!( + warp::test::request() + .path(&format!("{}.txt", location)) + .reply(&*ROUTES) + .body(), + "abc" + ); + } }