Pass database reference through a FromRequest - rust

I'm using the Rocket web framework in Rust and I have a database pool (DbConn).
#[database("redirect-api")]
pub struct DbConn(diesel::PgConnection);
I have been able to pass DbConn to each route, but I am having trouble passing it to a function that is called by from_request. In get_redirects I use the ApiKey struct and, as I understand it, FromRequest is a guard that validates the input. In from_request I call the is_valid function, but I can't pass it the DbConn the way I could if I were calling is_valid from a route, for example directly from get_redirects.
I want to be able to use DbConn in the is_valid function, but I get this error if I try to use &DbConn:
the trait `diesel::Connection` is not implemented for `fn(PooledConnection<<diesel::PgConnection as Poolable>::Manager>) -> DbConn {DbConn}`
|
43 | .load::<Token>(&DbConn)
| ^^^^^^^ the trait `diesel::Connection` is not implemented for `fn(PooledConnection<<diesel::PgConnection as Poolable>::Manager>) -> DbConn {DbConn}`
|
= note: required because of the requirements on the impl of `LoadQuery<fn(PooledConnection<<diesel::PgConnection as Poolable>::Manager>) -> DbConn {DbConn}, models::Token>` for `diesel::query_builder::SelectStatement<schema::tokens::table, query_builder::select_clause::DefaultSelectClause, query_builder::distinct_clause::NoDistinctClause, query_builder::where_clause::WhereClause<diesel::expression::operators::Eq<schema::tokens::columns::token, diesel::expression::bound::Bound<diesel::sql_types::Text, std::string::String>>>, query_builder::order_clause::NoOrderClause, query_builder::limit_clause::LimitClause<diesel::expression::bound::Bound<BigInt, i64>>>`
#![feature(decl_macro)]
use rocket::{catchers, routes};
#[macro_use]
extern crate diesel;
#[macro_use]
extern crate rocket_contrib;
use diesel::prelude::*;
// --snip--
#[database("redirect-api")]
pub struct DbConn(diesel::PgConnection);
// --snip--
fn is_valid(key: &str) {
// ... use key, hash it, etc
// check if hash is in database
// this function needs the database pool or DbConn
// but it can't be passed through FromRequest
}
impl<'a, 'r> FromRequest<'a, 'r> for ApiKey {
type Error = ApiKeyError;
fn from_request(request: &'a Request<'r>) -> request::Outcome<Self, Self::Error> {
let keys: Vec<_> = request.headers().get("x-api-key").collect();
match keys.len() {
0 => Outcome::Failure((Status::BadRequest, ApiKeyError::Missing)),
1 if is_valid(keys[0]) => Outcome::Success(ApiKey(keys[0].to_string())),
1 => Outcome::Failure((Status::BadRequest, ApiKeyError::Invalid)),
_ => Outcome::Failure((Status::BadRequest, ApiKeyError::BadCount)),
}
}
}
// --snip--
#[get("/redirects")]
pub fn get_redirects(_key: ApiKey, conn: DbConn) -> String {
// ... do database things with conn
format!("something {}", some_value_from_db)
}
// --snip--
fn main() {
rocket::ignite()
.register(catchers![not_found])
.attach(DbConn::fairing())
.mount("/", routes![root, get_redirects])
.launch();
}
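One way to read the error: `&DbConn` names the tuple struct's constructor function (hence the `fn(PooledConnection<...>) -> DbConn` in the message), not a connection value. A guard has to obtain an actual connection from the request first; in Rocket 0.4 that is typically done with `request.guard::<DbConn>()`. The following is a minimal std-only sketch of that flow, not Rocket code: `Request`, `guard_db`, `request_with` and the `NoDatabase` variant are hypothetical stand-ins, and a `HashSet` plays the role of the tokens table.

```rust
use std::collections::HashSet;

// Hypothetical stand-in for the pooled connection wrapper (models `DbConn`).
struct DbConn(HashSet<String>);

// Hypothetical stand-in for a request that can hand out a connection.
struct Request {
    api_key_headers: Vec<String>,
    known_tokens: HashSet<String>,
}

impl Request {
    // Models asking the request for another guard's value.
    fn guard_db(&self) -> Option<DbConn> {
        Some(DbConn(self.known_tokens.clone()))
    }
}

#[derive(Debug, PartialEq)]
enum ApiKeyError {
    Missing,
    Invalid,
    BadCount,
    NoDatabase, // assumption: how a failed connection lookup is reported
}

#[derive(Debug, PartialEq)]
struct ApiKey(String);

// Takes a connection *value* by reference, not the type name `DbConn`.
fn is_valid(key: &str, conn: &DbConn) -> bool {
    conn.0.contains(key)
}

fn from_request(request: &Request) -> Result<ApiKey, ApiKeyError> {
    // Obtain the connection from the request before validating the key.
    let conn = request.guard_db().ok_or(ApiKeyError::NoDatabase)?;
    match request.api_key_headers.as_slice() {
        [] => Err(ApiKeyError::Missing),
        [key] if is_valid(key, &conn) => Ok(ApiKey(key.clone())),
        [_] => Err(ApiKeyError::Invalid),
        _ => Err(ApiKeyError::BadCount),
    }
}

// Helper for building a request in examples.
fn request_with(headers: &[&str], tokens: &[&str]) -> Request {
    Request {
        api_key_headers: headers.iter().map(|h| h.to_string()).collect(),
        known_tokens: tokens.iter().map(|t| t.to_string()).collect(),
    }
}
```

The point of the sketch is the order of operations: the guard first gets a connection value from the request, then hands a reference to that value to is_valid.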

Related

Actix and Inventory crate

I'm trying to make it possible to register an actix route automatically. To do so I found the crate Inventory which seems to answer my need.
I have the following code:
#[get("/hello2/{name}")]
async fn greet2(req: HttpRequest) -> String {
let name = req.match_info().get("name").unwrap_or("World");
let a = format!("Hello2 {}!", &name);
a
}
pub struct TestStruct {
test: fn(HttpRequest) -> dyn Future<Output = String>
}
inventory::collect!(TestStruct);
inventory::submit! {
let mut a = TestStruct {
test: <greet2 as HttpServiceFactory>::register
};
a.test.push(greet2::register::greet2);
a
}
#[actix_web::main]
async fn main() -> std::io::Result<()> {
HttpServer::new(move || {
let mut app = App::new();
for route in inventory::iter::<AutomaticHttpService> {
app = app.service(route);
}
app
})
.bind(("127.0.0.1", 8080))?
.workers(2)
.run()
.await
}
My code does not compile as I cannot access the requested function pointer and I cannot use the HttpServiceFactory trait as it is not declared in the same file. I also tried to use the std::any::Any type and then cast to HttpServiceFactory (maybe using unsafe) but without success because its size was not known at compile time.
Would you have any clues to unblock my situation?
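A possible direction, sketched with std only: instead of storing the handler or a `dyn` type (whose size is unknown at compile time), store a plain function pointer that performs the registration. Everything here is a hypothetical model: `App`, `RouteRegistration`, the `register_*` functions and the static `REGISTRATIONS` slice stand in for the actix app and for whatever list the inventory crate would collect.

```rust
// Hypothetical stand-in for a web app that collects route paths.
#[derive(Default)]
struct App {
    routes: Vec<&'static str>,
}

impl App {
    fn service(&mut self, path: &'static str) {
        self.routes.push(path);
    }
}

// A function pointer has a known size, unlike `dyn Future` or `dyn Trait`,
// so it can be stored directly in a struct.
struct RouteRegistration {
    register: fn(&mut App),
}

fn register_greet(app: &mut App) {
    app.service("/hello/{name}");
}

fn register_greet2(app: &mut App) {
    app.service("/hello2/{name}");
}

// Stand-in for the collected registry; here it is just a static slice.
static REGISTRATIONS: &[RouteRegistration] = &[
    RouteRegistration { register: register_greet },
    RouteRegistration { register: register_greet2 },
];

fn build_app() -> App {
    let mut app = App::default();
    // Call each stored registration function on the app being built.
    for registration in REGISTRATIONS {
        (registration.register)(&mut app);
    }
    app
}
```

With this shape, each route module only needs to contribute one small function with a fixed signature, and the startup code never has to name the handler types.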

How can I get an Axum Handler function to return a Vec?

I'm learning how to use Axum with SQLx, starting with this example. The basic example works, but I have a problem moving forward. I am working with a simple database table as shown below:
todo | description
--------+--------------
todo_1 | doing todo 1
todo_2 | doing todo 2
todo_3 | doing todo 3
I am trying to simply get back "SELECT * FROM todos", but I am getting an error. I think I am getting the return of the Result type wrong but I am not sure what to do next. The entirety of main.rs is shown below.
//! Example of application using <https://github.com/launchbadge/sqlx>
//!
//! Run with
//!
//! ```not_rust
//! cd examples && cargo run -p example-sqlx-postgres
//! ```
//!
//! Test with curl:
//!
//! ```not_rust
//! curl 127.0.0.1:3000
//! curl -X POST 127.0.0.1:3000
//! ```
use axum::{
async_trait,
extract::{Extension, FromRequest, RequestParts},
http::StatusCode,
routing::get,
Router,
};
use sqlx::postgres::{PgPool, PgPoolOptions, PgRow};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use std::{net::SocketAddr, time::Duration};
#[tokio::main]
async fn main() {
tracing_subscriber::registry()
.with(tracing_subscriber::EnvFilter::new(
std::env::var("RUST_LOG").unwrap_or_else(|_| "example_tokio_postgres=debug".into()),
))
.with(tracing_subscriber::fmt::layer())
.init();
let db_connection_str = std::env::var("DATABASE_URL")
.unwrap_or_else(|_| "postgres://postgres:postgres@localhost".to_string());
// setup connection pool
let pool = PgPoolOptions::new()
.max_connections(5)
.connect_timeout(Duration::from_secs(3))
.connect(&db_connection_str)
.await
.expect("can connect to database");
// build our application with some routes
let app = Router::new()
.route(
"/",
get(using_connection_pool_extractor).post(using_connection_extractor),
)
.layer(Extension(pool));
// run it with hyper
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
tracing::debug!("listening on {}", addr);
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
}
// we can extract the connection pool with `Extension`
async fn using_connection_pool_extractor(
Extension(pool): Extension<PgPool>,
) -> Result<Vec<String>, (StatusCode, String)> {
sqlx::query_scalar("select * from todos")
.fetch_one(&pool)
.await
.map_err(internal_error)
}
// we can also write a custom extractor that grabs a connection from the pool
// which setup is appropriate depends on your application
struct DatabaseConnection(sqlx::pool::PoolConnection<sqlx::Postgres>);
#[async_trait]
impl<B> FromRequest<B> for DatabaseConnection
where
B: Send,
{
type Rejection = (StatusCode, String);
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let Extension(pool) = Extension::<PgPool>::from_request(req)
.await
.map_err(internal_error)?;
let conn = pool.acquire().await.map_err(internal_error)?;
Ok(Self(conn))
}
}
async fn using_connection_extractor(
DatabaseConnection(conn): DatabaseConnection,
) -> Result<String, (StatusCode, String)> {
let mut conn = conn;
sqlx::query_scalar("select 'hello world from pg'")
.fetch_one(&mut conn)
.await
.map_err(internal_error)
}
/// Utility function for mapping any error into a `500 Internal Server Error`
/// response.
fn internal_error<E>(err: E) -> (StatusCode, String)
where
E: std::error::Error,
{
(StatusCode::INTERNAL_SERVER_ERROR, err.to_string())
}
Compared to the example, I changed this function so that it returns a Vec<String> instead of a plain String, but I get a compiler error:
async fn using_connection_pool_extractor(
Extension(pool): Extension<PgPool>,
) -> Result<Vec<String>, (StatusCode, String)> {
sqlx::query_scalar("select * from todos")
.fetch_one(&pool)
.await
.map_err(internal_error)
}
error[E0277]: the trait bound `fn(Extension<Pool<sqlx::Postgres>>) -> impl Future<Output = Result<Vec<String>, (StatusCode, String)>> {using_connection_pool_extractor}: Handler<_, _>` is not satisfied
--> src/main.rs:52:17
|
52 | get(using_connection_pool_extractor).post(using_connection_extractor),
| --- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `Handler<_, _>` is not implemented for `fn(Extension<Pool<sqlx::Postgres>>) -> impl Future<Output = Result<Vec<String>, (StatusCode, String)>> {using_connection_pool_extractor}`
| |
| required by a bound introduced by this call
|
= help: the trait `Handler<T, ReqBody>` is implemented for `axum::handler::Layered<S, T>`
note: required by a bound in `axum::routing::get`
I am not sure what this error is suggesting or if it is even related to the actual problem.
Try using axum::Json:
async fn using_connection_pool_extractor(
Extension(pool): Extension<PgPool>,
) -> Result<axum::Json<Vec<String>>, (StatusCode, String)> {
sqlx::query_scalar("select * from todos")
// fetch_all collects every row into a Vec; fetch_one would read only a single row
.fetch_all(&pool)
.await
.map(|todos| axum::Json(todos))
.map_err(internal_error)
}
This is needed because axum does not implement the IntoResponse trait for Vec<T>, so a handler cannot return a bare Vec, whereas Json<T> implements IntoResponse for any serializable T. Here's a longer answer by Axum's author: https://users.rust-lang.org/t/axum-error-handling-trait-question/65530
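To illustrate why the wrapper helps, here is a simplified std-only model of the idea. It is not axum's implementation: the `IntoResponse` trait, `Json` wrapper and `respond` function below are hypothetical miniatures that only mirror the shape of the real ones, and the "JSON" encoding is a crude stand-in.

```rust
// Miniature model of a response-conversion trait.
trait IntoResponse {
    fn into_response(self) -> String;
}

// A plain String can be turned into a response directly.
impl IntoResponse for String {
    fn into_response(self) -> String {
        self
    }
}

// Newtype wrapper that says "encode the inner value as JSON".
struct Json<T>(T);

// Only the wrapped Vec gets an implementation; a bare Vec<String> has none.
impl IntoResponse for Json<Vec<String>> {
    fn into_response(self) -> String {
        // Debug formatting quotes each string; a crude stand-in for real JSON escaping.
        let items: Vec<String> = self.0.iter().map(|s| format!("{:?}", s)).collect();
        format!("[{}]", items.join(","))
    }
}

// Models the bound a router places on handler return types.
fn respond<R: IntoResponse>(value: R) -> String {
    value.into_response()
}

// respond(vec!["todo_1".to_string()]) would not compile:
// the trait bound `Vec<String>: IntoResponse` is not satisfied.
```

The bare Vec is rejected by the trait bound in the same way the handler in the question is rejected, while the wrapped value satisfies it.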

How do I use PickleDB with Rocket/Juniper Context?

I'm trying to write a Rocket / Juniper / Rust based GraphQL Server using PickleDB - an in-memory key/value store.
The pickle db is created / loaded at the start and given to rocket to manage:
fn rocket() -> Rocket {
let pickle_path = var_os(String::from("PICKLE_PATH")).unwrap_or(OsString::from("pickle.db"));
let pickle_db_dump_policy = PickleDbDumpPolicy::PeriodicDump(Duration::from_secs(120));
let pickle_serialization_method = SerializationMethod::Bin;
let pickle_db: PickleDb = match Path::new(&pickle_path).exists() {
false => PickleDb::new(pickle_path, pickle_db_dump_policy, pickle_serialization_method),
true => PickleDb::load(pickle_path, pickle_db_dump_policy, pickle_serialization_method).unwrap(),
};
rocket::ignite()
.manage(Schema::new(Query, Mutation))
.manage(pickle_db)
.mount(
"/",
routes![graphiql, get_graphql_handler, post_graphql_handler],
)
}
And I want to retrieve the PickleDb instance from the Rocket State in my Guard:
pub struct Context {
pickle_db: PickleDb,
}
impl juniper::Context for Context {}
impl<'a, 'r> FromRequest<'a, 'r> for Context {
type Error = ();
fn from_request(_request: &'a Request<'r>) -> request::Outcome<Context, ()> {
let pickle_db = _request.guard::<State<PickleDb>>()?.inner();
Outcome::Success(Context { pickle_db })
}
}
This does not work because the State only gives me a reference:
26 | Outcome::Success(Context { pickle_db })
| ^^^^^^^^^ expected struct `pickledb::pickledb::PickleDb`, found `&pickledb::pickledb::PickleDb`
When I change my Context struct to contain a reference I get lifetime issues which I'm not yet familiar with:
15 | pickle_db: &PickleDb,
| ^ expected named lifetime parameter
I tried using 'static, which makes Rust quite unhappy, and I tried to use the request lifetime (?) 'r of the FromRequest, but that does not really work either...
How do I get this to work? As I'm quite new to Rust, is this the right way to do things?
I finally have a solution, although the need for unsafe indicates it is sub-optimal :)
#![allow(unsafe_code)]
use pickledb::{PickleDb, PickleDbDumpPolicy, SerializationMethod};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::env;
use std::path::Path;
use std::time::Duration;
pub static mut PICKLE_DB: Option<PickleDb> = None;
pub fn cache_init() {
let pickle_path = env::var(String::from("PICKLE_PATH")).unwrap_or(String::from("pickle.db"));
let pickle_db_dump_policy = PickleDbDumpPolicy::PeriodicDump(Duration::from_secs(120));
let pickle_serialization_method = SerializationMethod::Json;
let pickle_db = match Path::new(&pickle_path).exists() {
false => PickleDb::new(
pickle_path,
pickle_db_dump_policy,
pickle_serialization_method,
),
true => PickleDb::load(
pickle_path,
pickle_db_dump_policy,
pickle_serialization_method,
)
.unwrap(),
};
unsafe {
PICKLE_DB = Some(pickle_db);
}
}
pub fn cache_get<V>(key: &str) -> Option<V>
where
V: DeserializeOwned + std::fmt::Debug,
{
unsafe {
let pickle_db = PICKLE_DB
.as_ref()
.expect("cache uninitialized - call cache_init()");
pickle_db.get::<V>(key)
}
}
pub fn cache_set<V>(key: &str, value: &V) -> Result<(), pickledb::error::Error>
where
V: Serialize,
{
unsafe {
let pickle_db = PICKLE_DB
.as_mut()
.expect("cache uninitialized - call cache_init()");
pickle_db.set::<V>(key, value)?;
Ok(())
}
}
This can simply be imported and used as expected, but I think I'll run into issues when the load gets too high...
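A possibly safer alternative to `static mut`, sketched with std only: keep the store behind an `Arc<RwLock<...>>` and let each context clone the handle, so a guard that only receives a reference can still build an owned context. `Store` (a `HashMap` standing in for PickleDb), `SharedStore` and `context_from_state` are hypothetical names for this illustration.

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};

// Hypothetical stand-in for the key/value database.
type Store = HashMap<String, String>;

// Cheaply cloneable handle; every clone points at the same store.
#[derive(Clone)]
struct SharedStore(Arc<RwLock<Store>>);

impl SharedStore {
    fn new() -> Self {
        SharedStore(Arc::new(RwLock::new(HashMap::new())))
    }

    // Many readers may hold the read lock at the same time.
    fn get(&self, key: &str) -> Option<String> {
        self.0.read().expect("lock poisoned").get(key).cloned()
    }

    // Writers take the lock exclusively, so no `unsafe` is needed.
    fn set(&self, key: &str, value: &str) {
        self.0
            .write()
            .expect("lock poisoned")
            .insert(key.to_string(), value.to_string());
    }
}

// Owned context with no lifetime parameter.
struct Context {
    store: SharedStore,
}

// Models a request guard: it only receives `&SharedStore`,
// but cloning the handle yields an owned value for the context.
fn context_from_state(state: &SharedStore) -> Context {
    Context { store: state.clone() }
}
```

Because the context owns a handle rather than a reference, the lifetime question from the original attempt does not come up, and the lock serializes writes under load.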

Why is a 'static variable needed when passing it to another function, but not within the same function?

I want to start a Hyper server in a function with port and dao parameters provided by main(), but the function only works after I explicitly indicate the 'static lifetime. This confused me a lot.
extern crate futures;
extern crate hyper;
use futures::future::Future;
use hyper::header::ContentLength;
use hyper::server::{Http, Request, Response, Service};
use std::net::SocketAddr;
trait Dao {}
struct MysqlDao;
impl Dao for MysqlDao {}
struct HelloWorld<'a> {
dao: &'a Dao,
}
const PHRASE: &'static str = "Hello, World!";
impl<'a> Service for HelloWorld<'a> {
type Request = Request;
type Response = Response;
type Error = hyper::Error;
type Future = Box<Future<Item = Self::Response, Error = Self::Error>>;
fn call(&self, _req: Request) -> Self::Future {
Box::new(futures::future::ok(
Response::new()
.with_header(ContentLength(PHRASE.len() as u64))
.with_body(PHRASE),
))
}
}
fn main() {
let addr = "127.0.0.1:3000".parse().unwrap();
let dao = MysqlDao;
let server = Http::new()
.bind(&addr, move || Ok(HelloWorld { dao: &dao }))
.unwrap();
server.run().unwrap();
}
The Http::new().bind API documentation says it needs a NewService + 'static, so I thought the compiler would infer that the dao variable is 'static, but when I move the last three statements out of main, it can't infer that!
fn main() {
let addr = "127.0.0.1:3000".parse().unwrap();
let dao: MysqlDao = MysqlDao;
web_startup(&addr, &dao);
}
fn web_startup<T: Dao>(addr: &SocketAddr, dao: &T) {
let server = Http::new()
.bind(addr, move || Ok(HelloWorld { dao }))
.unwrap();
server.run().unwrap();
}
I get the error:
error[E0477]: the type `[closure@src/main.rs:44:21: 44:51 dao:&T]` does not fulfill the required lifetime
--> src/main.rs:44:10
|
44 | .bind(addr, move || Ok(HelloWorld { dao }))
| ^^^^
|
= note: type must satisfy the static lifetime
So I fixed it:
fn main() {
let addr = "127.0.0.1:3000".parse().unwrap();
static DAO: MysqlDao = MysqlDao;
web_startup(&addr, &DAO);
}
fn web_startup<T: Dao>(addr: &SocketAddr, dao: &'static T) {
let server = Http::new()
.bind(addr, move || Ok(HelloWorld { dao }))
.unwrap();
server.run().unwrap();
}
I don't understand why I need the static keyword in the static DAO: MysqlDao = MysqlDao; statement when I did not need it before changing the code. Could the compiler not infer it, or am I thinking about this incorrectly?
The compiler cannot infer that web_startup will only ever be called with a 'static reference, because nothing guarantees it. What if the function were public and a third-party module called it? The compiler would then have to tell that caller to pass a 'static reference to a function whose signature doesn't ask for one. And what if eval() were someday added to Rust (e.g. for a REPL), so that even your private function could be called with unexpected arguments?
You're asking for an inference that should not happen.
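If a static item is not desirable, one alternative is to give the closure ownership of the value, for example through an `Arc`, so that it borrows nothing from the caller. This is a minimal std-only sketch: `bind` is a hypothetical stand-in that only mimics the `'static` requirement of hyper's API, and the `name` method is added to `Dao` purely for demonstration.

```rust
use std::sync::Arc;

trait Dao {
    fn name(&self) -> String;
}

struct MysqlDao;

impl Dao for MysqlDao {
    fn name(&self) -> String {
        "mysql".to_string()
    }
}

// Hypothetical stand-in for a server API that, like `bind`,
// requires the service factory to be `'static`.
fn bind<F>(factory: F) -> Box<dyn Fn() -> String>
where
    F: Fn() -> String + 'static,
{
    Box::new(factory)
}

// The `'static` bound is spelled out in the signature, so every caller sees it.
fn web_startup<T: Dao + 'static>(dao: Arc<T>) -> Box<dyn Fn() -> String> {
    // The closure owns its Arc, so it holds no borrow from the caller's stack.
    bind(move || dao.name())
}
```

Here `T: 'static` constrains the type rather than a reference, so a locally created `MysqlDao` wrapped in an `Arc` is accepted without declaring a static.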

Function local variable doesn't live long enough [duplicate]

This question already has answers here:
Is there any way to return a reference to a variable created in a function?
(5 answers)
Closed 5 years ago.
I'm trying to write a wrapper around serde_json & Rocket's FromData to strongly type some of the JSON I exchange with a server.
I can't compile the following code:
extern crate serde_json;
extern crate rocket;
extern crate serde;
use serde::ser::Error;
use serde_json::Value;
use rocket::data::DataStream;
use rocket::outcome::IntoOutcome;
use std::io::Read;
static NULL: Value = serde_json::Value::Null;
pub struct ResponseJSON<'v> {
success: bool,
http_code: u16,
data: &'v serde_json::Value,
}
impl<'v> ResponseJSON<'v> {
pub fn ok() -> ResponseJSON<'v> {
ResponseJSON {
success: true,
http_code: 200,
data: &NULL,
}
}
pub fn http_code(mut self, code: u16) -> ResponseJSON<'v> {
self.http_code = code;
self
}
pub fn data(mut self, ref_data: &'v serde_json::Value) -> ResponseJSON<'v> {
self.data = ref_data;
self
}
pub fn from_serde_value(json: &'v serde_json::Value) -> ResponseJSON<'v> {
if !json["success"].is_null() {
ResponseJSON::ok()
.http_code(json["http_code"].as_u64().unwrap() as u16)
.data(json.get("data").unwrap_or(&NULL))
} else {
ResponseJSON::ok()
.data(json.pointer("").unwrap())
}
}
}
impl<'v> rocket::data::FromData for ResponseJSON<'v> {
type Error = serde_json::Error;
fn from_data(request: &rocket::Request, data: rocket::Data) -> rocket::data::Outcome<Self, serde_json::Error> {
if !request.content_type().map_or(false, |ct| ct.is_json()) {
println!("Content-Type is not JSON.");
return rocket::Outcome::Forward(data);
}
let data_from_reader = data.open().take(1<<20);
let value = serde_json::from_reader(data_from_reader);
let unwraped_value : Value = if value.is_ok() { value.unwrap() } else { Value::Null };
if !unwraped_value.is_null() {
Ok(ResponseJSON::from_serde_value(&unwraped_value)).into_outcome()
} else {
Err(serde_json::Error::custom("Unable to create JSON from reader")).into_outcome()
}
}
}
fn main() {
println!("it runs!");
}
The compiler's error:
Compiling tests v0.1.0 (file:///Users/bgbahoue/Projects.nosync/tests)
error: `unwraped_value` does not live long enough
--> src/main.rs:64:48
|
64 | Ok(ResponseJSON::from_serde_value(&unwraped_value)).into_outcome()
| ^^^^^^^^^^^^^^ does not live long enough
...
68 | }
| - borrowed value only lives until here
|
note: borrowed value must be valid for the lifetime 'v as defined on the body at 53:114...
--> src/main.rs:53:115
|
53 | fn from_data(request: &rocket::Request, data: rocket::Data) -> rocket::data::Outcome<Self, serde_json::Error> {
| ___________________________________________________________________________________________________________________^
54 | | if !request.content_type().map_or(false, |ct| ct.is_json()) {
55 | | println!("Content-Type is not JSON.");
56 | | return rocket::Outcome::Forward(data);
... |
67 | | }
68 | | }
| |_____^
error: aborting due to previous error
Since data_from_reader, value and unwraped_value all come from data, I thought the compiler could infer that they have the same lifetime, but apparently that is not the case. Is there any way I could state that, or do something else that would work in such a case?
serde_json::from_reader:
pub fn from_reader<R, T>(rdr: R) -> Result<T>
where
R: Read,
T: DeserializeOwned,
rocket::data::Data::open:
fn open(self) -> DataStream
rocket::data::DataStream::take:
fn take(self, limit: u64) -> Take<Self>
As per @LukasKalbertodt's comment above, it does work when ResponseJSON owns the serde_json::Value.
Revised code (pasted as is, even though there are better ways to chain some parts of the code)
#![allow(dead_code)]
extern crate serde_json;
extern crate rocket;
extern crate serde;
use serde::ser::Error;
use serde_json::Value;
use rocket::outcome::IntoOutcome;
use std::io::Read;
static NULL: Value = serde_json::Value::Null;
pub struct ResponseJSON { // <-- changed to remove the lifetime parameter
success: bool,
http_code: u16,
data: serde_json::Value, // <- changed to remove the '&'
}
impl ResponseJSON {
pub fn ok() -> ResponseJSON {
ResponseJSON {
success: true,
http_code: 200,
data: Value::Null,
}
}
pub fn http_code(mut self, code: u16) -> ResponseJSON {
self.http_code = code;
self
}
pub fn data(mut self, data: serde_json::Value) -> ResponseJSON { // <- changed to remove the '&'
self.data = data;
self
}
pub fn from_serde_value(json: serde_json::Value) -> ResponseJSON { // <- changed to remove the reference & lifetime parameter
if !json["success"].is_null() {
ResponseJSON::ok()
.http_code(json["http_code"].as_u64().unwrap() as u16)
.data(json.get("data").unwrap_or(&NULL).clone())
} else {
ResponseJSON::ok()
.data(json.pointer("").unwrap().clone())
}
}
}
impl rocket::data::FromData for ResponseJSON {
type Error = serde_json::Error;
fn from_data(request: &rocket::Request, data: rocket::Data) -> rocket::data::Outcome<Self, serde_json::Error> {
if !request.content_type().map_or(false, |ct| ct.is_json()) {
println!("Content-Type is not JSON.");
return rocket::Outcome::Forward(data);
}
let data_from_reader = data.open().take(1<<20);
let value = serde_json::from_reader(data_from_reader);
let unwraped_value : Value = if value.is_ok() { value.unwrap() } else { Value::Null };
if !unwraped_value.is_null() {
Ok(ResponseJSON::from_serde_value(unwraped_value)).into_outcome() // <- changed to remove the '&' in front of `unwraped_value`
} else {
Err(serde_json::Error::custom("Unable to create JSON from reader")).into_outcome()
}
}
}
fn main() {
println!("it compiles & runs");
}
cargo run output
Compiling tests v0.1.0 (file:///Users/bgbahoue/Projects.nosync/tests)
Finished dev [unoptimized + debuginfo] target(s) in 1.28 secs
Running `target/debug/tests`
it compiles & runs
My take is that in this case ownership of the input parameter data is passed along from data_from_reader to value, to unwraped_value, to the temporary ResponseJSON, and finally to the returned rocket::data::Outcome, so it works.
With references, the returned ResponseJSON would have had to outlive the serde_json::Value it was created from, but unwraped_value is dropped at the end of the function; hence the compiler error.
I'm not 100% sure of my explanation though, and would love your thoughts on it.
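The same ownership argument can be shown in a few lines of std-only Rust. This is a reduced sketch, not the Rocket code: `Owned` and `parse_owned` are hypothetical names, and a `String` stands in for serde_json::Value.

```rust
// Owns its data, so it carries no lifetime parameter.
struct Owned {
    data: String,
}

fn parse_owned(input: &str) -> Owned {
    // `value` is a local, like `unwraped_value` in from_data.
    let value = input.trim().to_uppercase();
    // Moving `value` into the struct transfers ownership to the return value,
    // so nothing refers to the local once the function ends.
    Owned { data: value }
}

// The borrowing variant is rejected by the compiler, because the struct
// would hold a reference to a local that is dropped when the function returns:
//
// struct Borrowed<'v> { data: &'v String }
// fn parse_borrowed<'v>(input: &str) -> Borrowed<'v> {
//     let value = input.trim().to_uppercase();
//     Borrowed { data: &value } // error: `value` does not live long enough
// }
```

The lifetime 'v in the borrowing variant is chosen by the caller, so no local created inside the function can ever satisfy it; only moving the value out does.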
