222 lines
7.7 KiB
Rust
222 lines
7.7 KiB
Rust
|
#[cfg(not(target_env = "msvc"))]
|
||
|
use tikv_jemallocator::Jemalloc;
|
||
|
|
||
|
#[cfg(not(target_env = "msvc"))]
|
||
|
#[global_allocator]
|
||
|
static GLOBAL: Jemalloc = Jemalloc;
|
||
|
|
||
|
use actix_cors::Cors;
use actix_web::{http::header, middleware, web, App, HttpRequest, HttpResponse, HttpServer};
use anyhow::Result;
use clap::{Parser, Subcommand};
use diesel::prelude::*;
use juniper_actix::graphql_handler;
use serde::Deserialize;
use uuidv7::Uuid;

use gitea_pages::{api, db, init, init_nginx, worker, CONFIG};
#[derive(Debug, Parser)]
|
||
|
#[clap(author, version, about, long_about = None)]
|
||
|
struct Cli {
|
||
|
#[clap(subcommand)]
|
||
|
commands: Commands,
|
||
|
}
|
||
|
|
||
|
#[derive(Debug, Subcommand)]
|
||
|
enum Commands {
|
||
|
#[clap(about = "Starts Celery worker")]
|
||
|
Worker,
|
||
|
|
||
|
#[clap(about = "Starts beat for Celery worker")]
|
||
|
Beat,
|
||
|
|
||
|
#[clap(about = "Starts API")]
|
||
|
Api,
|
||
|
}
|
||
|
|
||
|
#[tokio::main]
|
||
|
async fn main() -> Result<()> {
|
||
|
let args = Cli::parse();
|
||
|
|
||
|
env_logger::init();
|
||
|
|
||
|
init()?;
|
||
|
|
||
|
match args.commands {
|
||
|
Commands::Worker => {
|
||
|
init_nginx()?;
|
||
|
|
||
|
let worker_conn = &worker::POOL.get().await.get().await?;
|
||
|
worker_conn.display_pretty().await;
|
||
|
worker_conn.consume_from(&[worker::QUEUE_NAME]).await?;
|
||
|
}
|
||
|
Commands::Beat => {
|
||
|
worker::beat().await?.start().await?;
|
||
|
}
|
||
|
Commands::Api => {
|
||
|
use actix_cors::Cors;
|
||
|
use actix_web::{
|
||
|
http::header, middleware, web, App, HttpRequest, HttpResponse, HttpServer,
|
||
|
};
|
||
|
use anyhow::Result;
|
||
|
use juniper_actix::graphql_handler;
|
||
|
use serde::Deserialize;
|
||
|
use std::str::FromStr;
|
||
|
|
||
|
async fn not_found() -> &'static str {
|
||
|
"Not found!"
|
||
|
}
|
||
|
|
||
|
#[derive(Deserialize, Debug)]
|
||
|
struct Owner {
|
||
|
username: String,
|
||
|
}
|
||
|
|
||
|
#[derive(Deserialize, Debug)]
|
||
|
struct Repository {
|
||
|
owner: Owner,
|
||
|
name: String,
|
||
|
}
|
||
|
|
||
|
#[derive(Deserialize, Debug)]
|
||
|
struct Payload {
|
||
|
repository: Repository,
|
||
|
}
|
||
|
|
||
|
async fn webhook(req: HttpRequest, body: web::Bytes) -> HttpResponse {
|
||
|
if let Some(content_type) = req.headers().get(header::CONTENT_TYPE) {
|
||
|
if content_type.as_bytes() != b"application/json" {
|
||
|
return HttpResponse::BadRequest()
|
||
|
.body("Content-Type not application/json");
|
||
|
}
|
||
|
}
|
||
|
|
||
|
let key =
|
||
|
ring::hmac::Key::new(ring::hmac::HMAC_SHA256, CONFIG.gitea_secret.as_bytes());
|
||
|
|
||
|
let signature = match req.headers().get("X-Gitea-Signature") {
|
||
|
Some(x) => x,
|
||
|
None => return HttpResponse::BadRequest().body("X-Gitea-Signature not given"),
|
||
|
};
|
||
|
let mut tag = vec![0u8; signature.len() / 2];
|
||
|
if hex_simd::decode(signature.as_bytes(), hex_simd::Out::from_slice(&mut tag))
|
||
|
.is_err()
|
||
|
{
|
||
|
return HttpResponse::BadRequest().body("Could not decode signature");
|
||
|
}
|
||
|
|
||
|
if ring::hmac::verify(&key, body.as_ref(), &tag).is_err() {
|
||
|
return HttpResponse::BadRequest().body("Invalid signature");
|
||
|
}
|
||
|
|
||
|
let payload = match serde_json::from_slice::<Payload>(body.as_ref()) {
|
||
|
Ok(x) => x,
|
||
|
Err(_) => return HttpResponse::BadRequest().body("Payload has invalid JSON"),
|
||
|
};
|
||
|
|
||
|
let db_conn = &mut match db::POOL.get() {
|
||
|
Ok(x) => x,
|
||
|
Err(_) => return HttpResponse::InternalServerError().finish(),
|
||
|
};
|
||
|
let user_id = match db::schema::users::table
|
||
|
.select(db::schema::users::id)
|
||
|
.filter(db::schema::users::name.eq(payload.repository.owner.username))
|
||
|
.first::<Uuid>(db_conn)
|
||
|
.optional()
|
||
|
{
|
||
|
Ok(x) => match x {
|
||
|
Some(y) => y,
|
||
|
None => {
|
||
|
return HttpResponse::BadRequest().body("Repository is not allowed")
|
||
|
}
|
||
|
},
|
||
|
Err(_) => return HttpResponse::InternalServerError().finish(),
|
||
|
};
|
||
|
let repo_id = match db::schema::repositories::table
|
||
|
.select(db::schema::repositories::id)
|
||
|
.filter(db::schema::repositories::user_id.eq(user_id))
|
||
|
.filter(db::schema::repositories::name.eq(payload.repository.name))
|
||
|
.first::<Uuid>(db_conn)
|
||
|
.optional()
|
||
|
{
|
||
|
Ok(x) => match x {
|
||
|
Some(y) => y,
|
||
|
None => {
|
||
|
return HttpResponse::BadRequest().body("Repository is not allowed")
|
||
|
}
|
||
|
},
|
||
|
Err(_) => return HttpResponse::InternalServerError().finish(),
|
||
|
};
|
||
|
|
||
|
let worker_conn = match worker::POOL.get().await.get().await {
|
||
|
Ok(x) => x,
|
||
|
Err(_) => return HttpResponse::InternalServerError().finish(),
|
||
|
};
|
||
|
if worker_conn
|
||
|
.send_task(worker::get_repo::get_repo::new(repo_id))
|
||
|
.await
|
||
|
.is_err()
|
||
|
{
|
||
|
return HttpResponse::InternalServerError().finish();
|
||
|
}
|
||
|
|
||
|
HttpResponse::Ok().finish()
|
||
|
}
|
||
|
|
||
|
async fn graphql_route(
|
||
|
req: HttpRequest,
|
||
|
payload: web::Payload,
|
||
|
schema: web::Data<api::Schema>,
|
||
|
) -> Result<HttpResponse, actix_web::Error> {
|
||
|
let logged_in = match req
|
||
|
.headers()
|
||
|
.get(header::AUTHORIZATION)
|
||
|
.and_then(|x| x.to_str().ok())
|
||
|
{
|
||
|
Some(x) => match http_auth_basic::Credentials::from_str(x) {
|
||
|
Ok(cred) => {
|
||
|
cred.user_id == CONFIG.user && cred.password == *CONFIG.password
|
||
|
}
|
||
|
Err(_) => false,
|
||
|
},
|
||
|
None => false,
|
||
|
};
|
||
|
|
||
|
let context = api::Context {
|
||
|
db_pool: db::POOL.clone(),
|
||
|
worker_pool: worker::POOL.get().await.clone(),
|
||
|
loaders: api::context::Loaders::default(),
|
||
|
logged_in,
|
||
|
};
|
||
|
|
||
|
graphql_handler(&schema, &context, req, payload).await
|
||
|
}
|
||
|
|
||
|
HttpServer::new(move || {
|
||
|
App::new()
|
||
|
.app_data(web::Data::new(api::schema()))
|
||
|
.wrap(middleware::Logger::default())
|
||
|
.wrap(middleware::Compress::default())
|
||
|
.wrap(
|
||
|
Cors::default()
|
||
|
.allow_any_origin()
|
||
|
.allowed_methods(["POST", "GET"])
|
||
|
.allowed_headers([header::AUTHORIZATION, header::ACCEPT])
|
||
|
.allowed_header(header::CONTENT_TYPE)
|
||
|
.supports_credentials()
|
||
|
.max_age(3600),
|
||
|
)
|
||
|
.service(web::resource("/webhook").route(web::post().to(webhook)))
|
||
|
.service(
|
||
|
web::resource("/graphql")
|
||
|
.route(web::post().to(graphql_route))
|
||
|
.route(web::get().to(graphql_route)),
|
||
|
)
|
||
|
.default_service(web::to(not_found))
|
||
|
})
|
||
|
.bind(("0.0.0.0", 8080))?
|
||
|
.run()
|
||
|
.await?;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
Ok(())
|
||
|
}
|