use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}; use actix_web::{http, FromRequest, HttpMessage}; use anyhow::Result; use futures::future::{err, ok, LocalBoxFuture}; use jsonwebtoken::{DecodingKey, Validation}; use kdash_protocol::Percent; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::fs; use std::future::{ready, Ready}; use std::marker::PhantomData; use std::path::Path; use std::sync::Arc; use uuid::Uuid; pub mod config; pub mod device; pub mod handlers; pub const VERSION: &str = env!("CARGO_PKG_VERSION"); #[derive(Debug, Deserialize, Clone)] pub struct Config { pub devices: HashMap<Uuid, device::Config>, } pub fn read_config(path: &Path) -> Result<Config> { let buf = fs::read(path)?; let config = serde_json::from_slice(&buf)?; Ok(config) } pub struct AppState { pub devices: HashMap<Uuid, device::Config>, pub devices_api: HashMap<Uuid, kdash_protocol::Config>, pub devices_api_json: HashMap<Uuid, Vec<u8>>, pub device_ids: HashSet<Uuid>, pub jwt_decoding_key: DecodingKey, pub jwt_validation: Validation, } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct Claims { pub exp: i64, pub sub: Uuid, } impl FromRequest for Claims { type Error = actix_web::Error; type Future = futures::future::Ready<Result<Self, Self::Error>>; fn from_request( req: &actix_web::HttpRequest, _payload: &mut actix_web::dev::Payload, ) -> Self::Future { match req.extensions().get::<Claims>() { Some(claims) => ok(claims.clone()), None => err(actix_web::error::ErrorUnauthorized("Unauthorized")), } } } pub struct Authority { pub decoding_key: DecodingKey, pub validation: Validation, } pub struct JwtAuth<Claims> { pub auth: Arc<Authority>, claims_marker: PhantomData<Claims>, } impl<Claims> JwtAuth<Claims> { pub fn new(auth: Arc<Authority>) -> Self { Self { auth, claims_marker: PhantomData, } } } impl<S, B, Claims> Transform<S, ServiceRequest> for JwtAuth<Claims> where S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>, S::Future: 'static, B: 'static, { type Response = ServiceResponse<B>; type Error = actix_web::Error; type InitError = (); type Transform = JwtAuthMiddleware<S, Claims>; type Future = Ready<Result<Self::Transform, Self::InitError>>; fn new_transform(&self, service: S) -> Self::Future { ready(Ok(JwtAuthMiddleware { service, auth: Arc::clone(&self.auth), claims_marker: PhantomData, })) } } pub struct JwtAuthMiddleware<S, Claims> { service: S, auth: Arc<Authority>, claims_marker: PhantomData<Claims>, } impl<S, B, C> Service<ServiceRequest> for JwtAuthMiddleware<S, C> where S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error>, S::Future: 'static, B: 'static, { type Response = ServiceResponse<B>; type Error = actix_web::Error; type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>; forward_ready!(service); fn call(&self, req: ServiceRequest) -> Self::Future { let decoded = req .headers() .get(http::header::AUTHORIZATION) .and_then(|h| h.to_str().ok()) .filter(|s| s.starts_with("Bearer ")) .map(|s| &s[7..]) .and_then(|token| { jsonwebtoken::decode::<Claims>( token, &self.auth.decoding_key, &self.auth.validation, ) .ok() }); if let Some(data) = decoded { req.extensions_mut().insert(data.claims); dbg!(req.extensions()); } Box::pin(self.service.call(req)) } } pub struct MqttClient { pub client: rumqttc::v5::AsyncClient, pub topic: String, pub discovery_topic: Option<String>, } #[derive(Serialize)] pub struct MqttState { pub battery_charging: bool, pub battery_level: Percent, pub battery_current: i16, pub battery_voltage: i16, } #[derive(Serialize)] pub struct MqttDeviceDiscovery<'a> { #[serde(rename = "dev")] pub device: MqttDeviceDiscoveryDevice<'a>, #[serde(rename = "o")] pub origin: MqttDeviceDiscoveryOrigin<'a>, #[serde(rename = "cmps")] pub components: HashMap<String, MqttDeviceDiscoveryComponent<'a>>, pub state_topic: &'a str, pub qos: u8, } #[derive(Serialize)] pub struct MqttDeviceDiscoveryDevice<'a> { #[serde(rename = "ids")] pub identifiers: &'a str, pub name: String, // #[serde(rename = "mf")] // pub manufacturer: &'a str, // #[serde(rename = "mdl")] // pub default_manufacturer: &'a str, // #[serde(rename = "sw")] // pub sw_version: &'a str, // #[serde(rename = "sn")] // pub serial_number: &'a str, // #[serde(rename = "hw")] // pub hw_version: &'a str, } #[derive(Serialize)] pub struct MqttDeviceDiscoveryOrigin<'a> { pub name: &'a str, #[serde(rename = "sw")] pub software_version: &'a str, pub url: &'a str, } #[derive(Serialize)] pub struct MqttDeviceDiscoveryComponent<'a> { #[serde(rename = "p")] pub platform: &'a str, pub device_class: &'a str, #[serde(skip_serializing_if = "Option::is_none")] #[serde(default)] pub unit_of_measurement: Option<&'a str>, pub value_template: &'a str, pub unique_id: &'a str, }