221 lines
5.5 KiB
Rust
221 lines
5.5 KiB
Rust
use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform};
use actix_web::{http, FromRequest, HttpMessage};
use anyhow::{Context, Result};
use futures::future::{err, ok, LocalBoxFuture};
use jsonwebtoken::{DecodingKey, Validation};
use kdash_protocol::Percent;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::fs;
use std::future::{ready, Ready};
use std::marker::PhantomData;
use std::path::Path;
use std::sync::Arc;
use uuid::Uuid;
|
|
|
|
pub mod config;
|
|
pub mod device;
|
|
pub mod handlers;
|
|
|
|
/// Version of this crate, read from the `CARGO_PKG_VERSION` variable that
/// Cargo sets at compile time (the `version` field of `Cargo.toml`).
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
|
|
|
|
/// Top-level configuration, deserialized from JSON by `read_config`.
#[derive(Debug, Deserialize, Clone)]
pub struct Config {
    /// Per-device configuration keyed by device UUID.
    pub devices: HashMap<Uuid, device::Config>,
}
|
|
|
|
pub fn read_config(path: &Path) -> Result<Config> {
|
|
let buf = fs::read(path)?;
|
|
let config = serde_json::from_slice(&buf)?;
|
|
|
|
Ok(config)
|
|
}
|
|
|
|
/// Shared application state.
pub struct AppState {
    /// Per-device server-side configuration keyed by device UUID.
    pub devices: HashMap<Uuid, device::Config>,

    /// Per-device protocol-level configuration keyed by device UUID.
    pub devices_api: HashMap<Uuid, kdash_protocol::Config>,

    /// Raw bytes per device UUID.
    // NOTE(review): presumably the pre-serialized JSON form of `devices_api`
    // so handlers can return it without re-encoding — confirm against the
    // code that builds this map.
    pub devices_api_json: HashMap<Uuid, Vec<u8>>,

    /// Set of device UUIDs.
    // NOTE(review): presumably the key set of `devices`, kept for fast
    // membership checks — verify.
    pub device_ids: HashSet<Uuid>,

    /// Key used to verify JWT signatures.
    pub jwt_decoding_key: DecodingKey,

    /// Rules applied when validating JWTs.
    pub jwt_validation: Validation,
}
|
|
|
|
/// JWT claims decoded by the authentication middleware and made available
/// to handlers through the `FromRequest` extractor.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
    /// Expiration time (JWT `exp` claim; a NumericDate, i.e. seconds since
    /// the Unix epoch, per RFC 7519).
    pub exp: i64,

    /// Subject (JWT `sub` claim).
    // NOTE(review): presumably the authenticated device's UUID — confirm.
    pub sub: Uuid,
}
|
|
|
|
impl FromRequest for Claims {
|
|
type Error = actix_web::Error;
|
|
type Future = futures::future::Ready<Result<Self, Self::Error>>;
|
|
|
|
fn from_request(
|
|
req: &actix_web::HttpRequest,
|
|
_payload: &mut actix_web::dev::Payload,
|
|
) -> Self::Future {
|
|
match req.extensions().get::<Claims>() {
|
|
Some(claims) => ok(claims.clone()),
|
|
None => err(actix_web::error::ErrorUnauthorized("Unauthorized")),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Verification material shared by every instance of the JWT middleware.
pub struct Authority {
    /// Key used to verify token signatures.
    pub decoding_key: DecodingKey,

    /// Validation rules passed to `jsonwebtoken::decode`.
    pub validation: Validation,
}
|
|
|
|
pub struct JwtAuth<Claims> {
|
|
pub auth: Arc<Authority>,
|
|
claims_marker: PhantomData<Claims>,
|
|
}
|
|
|
|
impl<Claims> JwtAuth<Claims> {
|
|
pub fn new(auth: Arc<Authority>) -> Self {
|
|
Self {
|
|
auth,
|
|
claims_marker: PhantomData,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<S, B, Claims> Transform<S, ServiceRequest> for JwtAuth<Claims>
|
|
where
|
|
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
|
|
S::Future: 'static,
|
|
B: 'static,
|
|
{
|
|
type Response = ServiceResponse<B>;
|
|
type Error = actix_web::Error;
|
|
type InitError = ();
|
|
type Transform = JwtAuthMiddleware<S, Claims>;
|
|
type Future = Ready<Result<Self::Transform, Self::InitError>>;
|
|
|
|
fn new_transform(&self, service: S) -> Self::Future {
|
|
ready(Ok(JwtAuthMiddleware {
|
|
service,
|
|
auth: Arc::clone(&self.auth),
|
|
claims_marker: PhantomData,
|
|
}))
|
|
}
|
|
}
|
|
|
|
pub struct JwtAuthMiddleware<S, Claims> {
|
|
service: S,
|
|
auth: Arc<Authority>,
|
|
claims_marker: PhantomData<Claims>,
|
|
}
|
|
|
|
impl<S, B, C> Service<ServiceRequest> for JwtAuthMiddleware<S, C>
|
|
where
|
|
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>,
|
|
S::Future: 'static,
|
|
B: 'static,
|
|
{
|
|
type Response = ServiceResponse<B>;
|
|
type Error = actix_web::Error;
|
|
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
|
|
|
|
forward_ready!(service);
|
|
|
|
fn call(&self, req: ServiceRequest) -> Self::Future {
|
|
let decoded = req
|
|
.headers()
|
|
.get(http::header::AUTHORIZATION)
|
|
.and_then(|h| h.to_str().ok())
|
|
.filter(|s| s.starts_with("Bearer "))
|
|
.map(|s| &s[7..])
|
|
.and_then(|token| {
|
|
jsonwebtoken::decode::<Claims>(
|
|
token,
|
|
&self.auth.decoding_key,
|
|
&self.auth.validation,
|
|
)
|
|
.ok()
|
|
});
|
|
|
|
if let Some(data) = decoded {
|
|
req.extensions_mut().insert(data.claims);
|
|
dbg!(req.extensions());
|
|
}
|
|
|
|
Box::pin(self.service.call(req))
|
|
}
|
|
}
|
|
|
|
/// MQTT connection handle together with the topics it publishes to.
pub struct MqttClient {
    /// Async MQTT v5 client (rumqttc).
    pub client: rumqttc::v5::AsyncClient,

    /// Topic name.
    // NOTE(review): presumably the base/state topic devices publish under —
    // confirm against the publishing code.
    pub topic: String,

    /// Optional discovery topic; `None` presumably disables discovery
    /// announcements — verify.
    pub discovery_topic: Option<String>,
}
|
|
|
|
/// Serialized battery state payload.
// NOTE(review): presumably published on the state topic referenced by the
// discovery payload — confirm.
#[derive(Serialize)]
pub struct MqttState {
    /// Whether the battery is currently charging.
    pub battery_charging: bool,

    /// Battery charge level.
    pub battery_level: Percent,

    /// Battery current; unit not established in this file — TODO confirm.
    pub battery_current: i16,

    /// Battery voltage; unit not established in this file — TODO confirm.
    pub battery_voltage: i16,
}
|
|
|
|
/// Device discovery payload. It uses the abbreviated keys (`dev`, `o`,
/// `cmps`) that appear to follow Home Assistant's MQTT device-discovery
/// schema.
#[derive(Serialize)]
pub struct MqttDeviceDiscovery<'a> {
    /// Device description, serialized as `dev`.
    #[serde(rename = "dev")]
    pub device: MqttDeviceDiscoveryDevice<'a>,

    /// Origin (publishing software) description, serialized as `o`.
    #[serde(rename = "o")]
    pub origin: MqttDeviceDiscoveryOrigin<'a>,

    /// Entity components keyed by component id, serialized as `cmps`.
    #[serde(rename = "cmps")]
    pub components: HashMap<String, MqttDeviceDiscoveryComponent<'a>>,

    /// Topic the components read their state from.
    pub state_topic: &'a str,

    /// MQTT QoS level for the components.
    pub qos: u8,
}
|
|
|
|
/// Device block (`dev`) of the discovery payload.
#[derive(Serialize)]
pub struct MqttDeviceDiscoveryDevice<'a> {
    /// Device identifier, serialized under the abbreviated key `ids`.
    #[serde(rename = "ids")]
    pub identifiers: &'a str,

    /// Human-readable device name.
    pub name: String,
    // Optional device fields that are not populated yet.

    // #[serde(rename = "mf")]
    // pub manufacturer: &'a str,

    // NOTE(review): `mdl` is the abbreviation for the device *model*, so this
    // field should probably be named `model` when it is enabled.
    // #[serde(rename = "mdl")]
    // pub default_manufacturer: &'a str,

    // #[serde(rename = "sw")]
    // pub sw_version: &'a str,

    // #[serde(rename = "sn")]
    // pub serial_number: &'a str,

    // #[serde(rename = "hw")]
    // pub hw_version: &'a str,
}
|
|
|
|
/// Origin block (`o`) of the discovery payload, which identifies the
/// software publishing it.
#[derive(Serialize)]
pub struct MqttDeviceDiscoveryOrigin<'a> {
    /// Name of the publishing software.
    pub name: &'a str,

    /// Software version, serialized under the abbreviated key `sw`.
    #[serde(rename = "sw")]
    pub software_version: &'a str,

    /// URL associated with the publishing software.
    // NOTE(review): presumably a support/project URL — confirm.
    pub url: &'a str,
}
|
|
|
|
#[derive(Serialize)]
|
|
pub struct MqttDeviceDiscoveryComponent<'a> {
|
|
#[serde(rename = "p")]
|
|
pub platform: &'a str,
|
|
pub device_class: &'a str,
|
|
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
#[serde(default)]
|
|
pub unit_of_measurement: Option<&'a str>,
|
|
|
|
pub value_template: &'a str,
|
|
pub unique_id: &'a str,
|
|
}
|