use actix_web::dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}; use actix_web::{http, FromRequest, HttpMessage}; use anyhow::Result; use futures::future::{err, ok, LocalBoxFuture}; use jsonwebtoken::{DecodingKey, Validation}; use kdash_protocol::Percent; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::fs; use std::future::{ready, Ready}; use std::marker::PhantomData; use std::path::Path; use std::sync::Arc; use uuid::Uuid; pub mod config; pub mod device; pub mod handlers; pub const VERSION: &str = env!("CARGO_PKG_VERSION"); #[derive(Debug, Deserialize, Clone)] pub struct Config { pub devices: HashMap, } pub fn read_config(path: &Path) -> Result { let buf = fs::read(path)?; let config = serde_json::from_slice(&buf)?; Ok(config) } pub struct AppState { pub devices: HashMap, pub devices_api: HashMap, pub devices_api_json: HashMap>, pub device_ids: HashSet, pub jwt_decoding_key: DecodingKey, pub jwt_validation: Validation, } #[derive(Debug, Serialize, Deserialize, Clone)] pub struct Claims { pub exp: i64, pub sub: Uuid, } impl FromRequest for Claims { type Error = actix_web::Error; type Future = futures::future::Ready>; fn from_request( req: &actix_web::HttpRequest, _payload: &mut actix_web::dev::Payload, ) -> Self::Future { match req.extensions().get::() { Some(claims) => ok(claims.clone()), None => err(actix_web::error::ErrorUnauthorized("Unauthorized")), } } } pub struct Authority { pub decoding_key: DecodingKey, pub validation: Validation, } pub struct JwtAuth { pub auth: Arc, claims_marker: PhantomData, } impl JwtAuth { pub fn new(auth: Arc) -> Self { Self { auth, claims_marker: PhantomData, } } } impl Transform for JwtAuth where S: Service, Error = actix_web::Error>, S::Future: 'static, B: 'static, { type Response = ServiceResponse; type Error = actix_web::Error; type InitError = (); type Transform = JwtAuthMiddleware; type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ready(Ok(JwtAuthMiddleware { service, auth: Arc::clone(&self.auth), claims_marker: PhantomData, })) } } pub struct JwtAuthMiddleware { service: S, auth: Arc, claims_marker: PhantomData, } impl Service for JwtAuthMiddleware where S: Service, Error = actix_web::Error>, S::Future: 'static, B: 'static, { type Response = ServiceResponse; type Error = actix_web::Error; type Future = LocalBoxFuture<'static, Result>; forward_ready!(service); fn call(&self, req: ServiceRequest) -> Self::Future { let decoded = req .headers() .get(http::header::AUTHORIZATION) .and_then(|h| h.to_str().ok()) .filter(|s| s.starts_with("Bearer ")) .map(|s| &s[7..]) .and_then(|token| { jsonwebtoken::decode::( token, &self.auth.decoding_key, &self.auth.validation, ) .ok() }); if let Some(data) = decoded { req.extensions_mut().insert(data.claims); dbg!(req.extensions()); } Box::pin(self.service.call(req)) } } pub struct MqttClient { pub client: rumqttc::v5::AsyncClient, pub topic: String, pub discovery_topic: Option, } #[derive(Serialize)] pub struct MqttState { pub battery_charging: bool, pub battery_level: Percent, pub battery_current: i16, pub battery_voltage: i16, } #[derive(Serialize)] pub struct MqttDeviceDiscovery<'a> { #[serde(rename = "dev")] pub device: MqttDeviceDiscoveryDevice<'a>, #[serde(rename = "o")] pub origin: MqttDeviceDiscoveryOrigin<'a>, #[serde(rename = "cmps")] pub components: HashMap>, pub state_topic: &'a str, pub qos: u8, } #[derive(Serialize)] pub struct MqttDeviceDiscoveryDevice<'a> { #[serde(rename = "ids")] pub identifiers: &'a str, pub name: String, // #[serde(rename = "mf")] // pub manufacturer: &'a str, // #[serde(rename = "mdl")] // pub default_manufacturer: &'a str, // #[serde(rename = "sw")] // pub sw_version: &'a str, // #[serde(rename = "sn")] // pub serial_number: &'a str, // #[serde(rename = "hw")] // pub hw_version: &'a str, } #[derive(Serialize)] pub struct MqttDeviceDiscoveryOrigin<'a> { pub name: &'a str, #[serde(rename = "sw")] pub software_version: &'a str, pub url: &'a str, } #[derive(Serialize)] pub struct MqttDeviceDiscoveryComponent<'a> { #[serde(rename = "p")] pub platform: &'a str, pub device_class: &'a str, #[serde(skip_serializing_if = "Option::is_none")] #[serde(default)] pub unit_of_measurement: Option<&'a str>, pub value_template: &'a str, pub unique_id: &'a str, }