gitea_pages/backend/src/main.rs
2023-05-21 18:58:10 +02:00

221 lines
7.7 KiB
Rust

#[cfg(not(target_env = "msvc"))]
use tikv_jemallocator::Jemalloc;
// Make jemalloc the process-wide global allocator. The cfg excludes MSVC
// targets (presumably because tikv-jemallocator does not support them —
// confirm); those builds fall back to the default system allocator.
#[cfg(not(target_env = "msvc"))]
#[global_allocator]
static GLOBAL: Jemalloc = Jemalloc;
use anyhow::Result;
use clap::{Parser, Subcommand};
use diesel::prelude::*;
use uuidv7::Uuid;
use gitea_pages::{api, db, init, init_nginx, worker, CONFIG};
// Top-level command-line interface. Author, version and about text come from
// the crate metadata via the clap attribute below.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc
// comments into help text, which would change `--help` output.
#[derive(Debug, Parser)]
#[clap(author, version, about, long_about = None)]
struct Cli {
// The component to run; exactly one subcommand is required.
#[clap(subcommand)]
commands: Commands,
}
// The long-running components this binary can start; `main` dispatches on it.
// Each variant's `about` string is the description shown by `--help`.
#[derive(Debug, Subcommand)]
enum Commands {
// Task consumer: `main` initialises nginx, then consumes `worker::QUEUE_NAME`.
#[clap(about = "Starts Celery worker")]
Worker,
// Scheduler: `main` builds it with `worker::beat()` and starts it.
#[clap(about = "Starts beat for Celery worker")]
Beat,
// HTTP server exposing `/webhook` and `/graphql`.
#[clap(about = "Starts API")]
Api,
}
/// Entry point: parses the command line, initialises logging and shared
/// state, then runs the component selected by the subcommand until it exits.
///
/// # Errors
///
/// Returns any error from `init`, from setting up or running the worker or
/// beat scheduler, or from binding/running the HTTP server.
#[tokio::main]
async fn main() -> Result<()> {
    let args = Cli::parse();
    env_logger::init();
    init()?;
    match args.commands {
        Commands::Worker => {
            init_nginx()?;
            let worker_conn = &worker::POOL.get().await.get().await?;
            worker_conn.display_pretty().await;
            worker_conn.consume_from(&[worker::QUEUE_NAME]).await?;
        }
        Commands::Beat => {
            worker::beat().await?.start().await?;
        }
        Commands::Api => {
            http_api::serve().await?;
        }
    }
    Ok(())
}

/// HTTP front end started by the `api` subcommand: the Gitea webhook receiver
/// and the GraphQL endpoint.
mod http_api {
    use super::*;
    use actix_cors::Cors;
    use actix_web::{
        http::header, middleware, web, App, HttpRequest, HttpResponse, HttpServer,
    };
    use juniper_actix::graphql_handler;
    use serde::Deserialize;
    use std::str::FromStr;

    /// The subset of the Gitea webhook JSON body this service reads; serde
    /// ignores all other fields.
    #[derive(Deserialize, Debug)]
    struct Payload {
        repository: Repository,
    }

    #[derive(Deserialize, Debug)]
    struct Repository {
        owner: Owner,
        name: String,
    }

    #[derive(Deserialize, Debug)]
    struct Owner {
        username: String,
    }

    /// Binds the server on `0.0.0.0:8080` and serves requests until shutdown.
    ///
    /// # Errors
    ///
    /// Returns an error if the address cannot be bound or the server fails.
    pub(super) async fn serve() -> Result<()> {
        HttpServer::new(|| {
            App::new()
                .app_data(web::Data::new(api::schema()))
                .wrap(middleware::Logger::default())
                .wrap(middleware::Compress::default())
                .wrap(
                    // NOTE(review): any origin combined with credentials lets
                    // arbitrary sites send credentialed requests (browsers may
                    // attach cached Basic-auth) — confirm this is intended.
                    Cors::default()
                        .allow_any_origin()
                        .allowed_methods(["POST", "GET"])
                        .allowed_headers([header::AUTHORIZATION, header::ACCEPT])
                        .allowed_header(header::CONTENT_TYPE)
                        .supports_credentials()
                        .max_age(3600),
                )
                .service(web::resource("/webhook").route(web::post().to(webhook)))
                .service(
                    web::resource("/graphql")
                        .route(web::post().to(graphql_route))
                        .route(web::get().to(graphql_route)),
                )
                .default_service(web::to(not_found))
        })
        .bind(("0.0.0.0", 8080))?
        .run()
        .await?;
        Ok(())
    }

    /// Fallback for unmatched routes. Replies `404 Not Found`; a bare
    /// `&'static str` responder would have been sent as `200 OK`.
    async fn not_found() -> HttpResponse {
        HttpResponse::NotFound().body("Not found!")
    }

    /// Handles a Gitea webhook: authenticates it, resolves the repository and
    /// enqueues a `get_repo` task for it.
    ///
    /// Replies `400` for a non-JSON `Content-Type`, a missing, undecodable or
    /// invalid `X-Gitea-Signature`, malformed JSON, or a repository not present
    /// in the database. Replies `500` for database or task-queue failures, and
    /// `200` once the task has been sent.
    async fn webhook(req: HttpRequest, body: web::Bytes) -> HttpResponse {
        // A missing Content-Type is tolerated; only a different one is rejected.
        if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
            if content_type.as_bytes() != b"application/json" {
                return HttpResponse::BadRequest().body("Content-Type not application/json");
            }
        }
        let signature = match req.headers().get("X-Gitea-Signature") {
            Some(x) => x,
            None => return HttpResponse::BadRequest().body("X-Gitea-Signature not given"),
        };
        // The header is the hex-encoded HMAC-SHA256 of the raw request body.
        let mut tag = vec![0u8; signature.len() / 2];
        if hex_simd::decode(signature.as_bytes(), hex_simd::Out::from_slice(&mut tag)).is_err()
        {
            return HttpResponse::BadRequest().body("Could not decode signature");
        }
        let key = ring::hmac::Key::new(ring::hmac::HMAC_SHA256, CONFIG.gitea_secret.as_bytes());
        if ring::hmac::verify(&key, body.as_ref(), &tag).is_err() {
            return HttpResponse::BadRequest().body("Invalid signature");
        }
        let payload = match serde_json::from_slice::<Payload>(body.as_ref()) {
            Ok(x) => x,
            Err(_) => return HttpResponse::BadRequest().body("Payload has invalid JSON"),
        };
        let owner = payload.repository.owner.username;
        let repo_name = payload.repository.name;
        // The pool checkout and Diesel queries are synchronous; run them on
        // actix's blocking thread pool so a slow database does not stall the
        // async worker that is serving other requests.
        let lookup = web::block(move || find_repository_id(&owner, &repo_name)).await;
        let repo_id = match lookup {
            Ok(Ok(Some(id))) => id,
            Ok(Ok(None)) => return HttpResponse::BadRequest().body("Repository is not allowed"),
            Ok(Err(_)) | Err(_) => return HttpResponse::InternalServerError().finish(),
        };
        let worker_conn = match worker::POOL.get().await.get().await {
            Ok(x) => x,
            Err(_) => return HttpResponse::InternalServerError().finish(),
        };
        match worker_conn
            .send_task(worker::get_repo::get_repo::new(repo_id))
            .await
        {
            Ok(_) => HttpResponse::Ok().finish(),
            Err(_) => HttpResponse::InternalServerError().finish(),
        }
    }

    /// Returns the id of repository `repo_name` owned by user `owner`, or
    /// `None` if either the user or that repository is absent from the
    /// database.
    ///
    /// Blocking: performs synchronous database I/O, so call it via
    /// `web::block` from async code.
    ///
    /// # Errors
    ///
    /// Returns an error if no connection can be obtained or a query fails.
    fn find_repository_id(owner: &str, repo_name: &str) -> Result<Option<Uuid>> {
        let db_conn = &mut db::POOL.get()?;
        let user_id = match db::schema::users::table
            .select(db::schema::users::id)
            .filter(db::schema::users::name.eq(owner))
            .first::<Uuid>(db_conn)
            .optional()?
        {
            Some(id) => id,
            None => return Ok(None),
        };
        let repo_id = db::schema::repositories::table
            .select(db::schema::repositories::id)
            .filter(db::schema::repositories::user_id.eq(user_id))
            .filter(db::schema::repositories::name.eq(repo_name))
            .first::<Uuid>(db_conn)
            .optional()?;
        Ok(repo_id)
    }

    /// Executes a GraphQL request (GET or POST).
    ///
    /// A Basic `Authorization` header matching `CONFIG.user` and
    /// `CONFIG.password` marks the request as logged in. A missing,
    /// unparsable or non-matching header is served with `logged_in = false`.
    async fn graphql_route(
        req: HttpRequest,
        payload: web::Payload,
        schema: web::Data<api::Schema>,
    ) -> Result<HttpResponse, actix_web::Error> {
        // NOTE(review): `==` on strings is not guaranteed constant-time;
        // consider a constant-time comparison for the password.
        let logged_in = req
            .headers()
            .get(header::AUTHORIZATION)
            .and_then(|x| x.to_str().ok())
            .and_then(|x| http_auth_basic::Credentials::from_str(x).ok())
            .map_or(false, |cred| {
                cred.user_id == CONFIG.user && cred.password == *CONFIG.password
            });
        let context = api::Context {
            db_pool: db::POOL.clone(),
            worker_pool: worker::POOL.get().await.clone(),
            loaders: api::context::Loaders::default(),
            logged_in,
        };
        graphql_handler(&schema, &context, req, payload).await
    }
}