diff --git a/Cargo.lock b/Cargo.lock index 8aa6a23..5d39847 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -141,6 +141,12 @@ dependencies = [ "os_str_bytes", ] +[[package]] +name = "either" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f107b87b6afc2a64fd13cac55fe06d6c8859f12d4b14cbcdd2c67d0976781be" + [[package]] name = "flume" version = "0.10.13" @@ -304,6 +310,22 @@ dependencies = [ "itoa", ] +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + +[[package]] +name = "humantime-serde" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57a3db5ea5923d99402c94e9feb261dc5ee9b4efa158b0315f788cf549cc200c" +dependencies = [ + "humantime", + "serde", +] + [[package]] name = "idna" version = "0.2.3" @@ -325,6 +347,15 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "itertools" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9a9d19fa1e79b6215ff29b9d6880b706147f16e9b1dbb1e4e5947b5b02bc5e3" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.2" @@ -461,6 +492,8 @@ version = "0.1.0" dependencies = [ "bytes", "clap", + "humantime-serde", + "itertools", "rumqttc", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index b5449be..41e8534 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,10 +8,12 @@ edition = "2021" [dependencies] bytes = "1.1.0" clap = { version = "3.2.12", features = ["derive", "env"] } +humantime-serde = "1.1.1" +itertools = "0.10.3" rumqttc = { version = "0.13.0", features = ["url"], git = "https://github.com/bytebeamio/rumqtt" } serde = { version = "1.0.139", features = ["serde_derive"] } serde_json = "1.0.82" serialport = { version = "4.2.0", features = ["serde"] } -tokio = { version = "1.20.0", features = ["rt", "rt-multi-thread"] } +tokio = { version = "1.20.0", features = ["rt", "rt-multi-thread", "time"] } tokio-modbus = "0.5.3" tokio-serial = "5.4.3" diff --git a/src/main.rs b/src/main.rs index ea427c1..bbce7ab 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,7 @@ use rumqttc::{self, AsyncClient, Event, Incoming, LastWill, MqttOptions, Publish use serde::{Deserialize, Serialize}; use serde_json::json; use std::{collections::HashMap, time::Duration}; -use tokio::sync::mpsc; +use tokio::{sync::mpsc, sync::oneshot, time::MissedTickBehavior}; use tokio_modbus::prelude::*; use clap::Parser; @@ -28,7 +28,7 @@ struct Cli { mqtt_topic_prefix: String, } -#[derive(Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] #[serde(untagged)] enum ModbusProto { Tcp { @@ -77,13 +77,9 @@ fn default_modbus_parity() -> tokio_serial::Parity { tokio_serial::Parity::None } -#[derive(Serialize, Deserialize)] -struct Range { - address: u16, - size: u16, -} - // TODO: `scale`, `offset`, `precision` +// TODO: migrate `count` from `Range` into this enum to force the correct size? +#[derive(Clone, Serialize, Deserialize)] enum RegisterValueType { U8, U16, @@ -95,23 +91,35 @@ enum RegisterValueType { I64, F32, F64, - String, + // Array(u16, RegisterValueType), + String(u16), } -#[derive(Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] struct RegisterParse { #[serde(default = "default_swap")] swap_bytes: bool, #[serde(default = "default_swap")] swap_words: bool, + + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + value_type: Option, } fn default_swap() -> bool { false } -#[derive(Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] +struct Range { + address: u16, + + #[serde(alias = "size")] + count: u8, // Modbus limits to 125 in fact - https://github.com/slowtec/tokio-modbus/issues/112#issuecomment-1095316069= +} + +#[derive(Clone, Serialize, Deserialize)] struct Register { #[serde(flatten)] range: Range, @@ -119,30 +127,51 @@ struct Register { #[serde(skip_serializing_if = "Option::is_none")] name: Option, + #[serde(skip_serializing_if = "Option::is_none")] parse: Option, + + #[serde(with = "humantime_serde", default = "default_register_interval")] + interval: Duration, } -#[derive(Serialize, Deserialize)] +fn default_register_interval() -> Duration { + Duration::from_secs(10) +} + +#[derive(Clone, Serialize, Deserialize)] struct Connect { #[serde(flatten)] settings: ModbusProto, - // input_ranges: Vec, - // hold_ranges: Vec, - #[serde(default = "default_modbus_unit")] - slave: u8, // TODO make `Slave` but need custom deserializer I think + #[serde(default, skip_serializing_if = "Vec::is_empty")] + input: Vec, + + #[serde(default, skip_serializing_if = "Vec::is_empty")] + hold: Vec, + + #[serde(alias = "slave", default = "default_modbus_unit", with = "ext::Unit")] + unit: Unit, #[serde(default = "default_address_offset")] address_offset: i8, } -fn default_modbus_unit() -> u8 { - 0 +fn default_modbus_unit() -> Unit { + Slave(0) } fn default_address_offset() -> i8 { 0 } +type UnitId = SlaveId; +type Unit = Slave; +mod ext { + use serde::{Deserialize, Serialize}; + #[derive(Serialize, Deserialize)] + #[serde(remote = "tokio_modbus::slave::Slave")] + pub struct Unit(pub crate::UnitId); +} + #[derive(Serialize)] #[serde(rename_all = "lowercase")] enum ConnectState { @@ -204,7 +233,6 @@ async fn main() { enum DispatchCommand { Publish { topic: String, payload: Vec }, } - async fn mqtt_dispatcher( mut options: MqttOptions, prefix: String, @@ -344,6 +372,20 @@ async fn connection_registry( } } +#[derive(Debug)] +enum ModbusReadType { + Input, + Hold, +} + +#[derive(Debug)] +enum ModbusCommand { + Read(ModbusReadType, u16, u8, ModbusResponse), + Write(u16, Vec, ModbusResponse), +} + +type ModbusResponse = oneshot::Sender, std::io::Error>>; + async fn handle_connect( dispatcher: mpsc::Sender, id: ConnectionId, @@ -353,13 +395,13 @@ async fn handle_connect( println!("Starting connection handler for {}", id); match serde_json::from_slice::(&payload) { Ok(connect) => { - let slave = Slave(connect.slave); + let unit = connect.unit; // println!("{:?}", connect); let mut modbus = match connect.settings { ModbusProto::Tcp { ref host, port } => { let socket_addr = format!("{}:{}", host, port).parse().unwrap(); - tcp::connect_slave(socket_addr, slave).await.unwrap() + tcp::connect_slave(socket_addr, unit).await.unwrap() } ModbusProto::Rtu { ref tty, @@ -375,11 +417,11 @@ async fn handle_connect( .parity(parity) .stop_bits(stop_bits); let port = tokio_serial::SerialStream::open(&builder).unwrap(); - rtu::connect_slave(port, slave).await.unwrap() + rtu::connect_slave(port, unit).await.unwrap() } }; let status = ConnectStatus { - connect: connect, + connect: connect.clone(), status: ConnectState::Connected, }; dispatcher @@ -389,6 +431,107 @@ async fn handle_connect( }) .await .unwrap(); + + let (modbus_tx, mut modbus_rx) = mpsc::channel::(32); + tokio::spawn(async move { + while let Some(command) = modbus_rx.recv().await { + match command { + ModbusCommand::Read(read_type, address, count, responder) => { + let response = match read_type { + ModbusReadType::Input => { + modbus.read_input_registers(address, count as u16) + } + ModbusReadType::Hold => { + modbus.read_holding_registers(address, count as u16) + } + }; + + responder.send(response.await).unwrap(); + } + ModbusCommand::Write(address, data, responder) => { + responder + .send( + modbus + .write_multiple_registers(address, &data[..]) + .await + .map(|_| vec![]), + ) + .unwrap(); + } + } + } + }); + + use itertools::Itertools; + for (duration, registers) in &connect.input.into_iter().group_by(|r| r.interval) { + let registers: Vec = registers.collect(); + let id = id.clone(); + let modbus = modbus_tx.clone(); + let dispatcher = dispatcher.clone(); + let topic_prefix = topic_prefix.clone(); + + tokio::spawn(async move { + let mut interval = tokio::time::interval(duration); + interval.set_missed_tick_behavior(MissedTickBehavior::Delay); + + loop { + interval.tick().await; + for ref r in registers.iter() { + let address = if connect.address_offset >= 0 { + r.range.address.checked_add(connect.address_offset as u16) + } else { + r.range + .address + .checked_sub(connect.address_offset.unsigned_abs() as u16) + }; + if let Some(address) = address { + println!("Polling {}", address); + + let (tx, rx) = oneshot::channel(); + + modbus + .send(ModbusCommand::Read( + ModbusReadType::Input, + address, + r.range.count.into(), + tx, + )) + .await + .unwrap(); + + let values = rx.await.unwrap().unwrap(); + + let payload = + serde_json::to_vec(&json!({ "raw": values, })).unwrap(); + + dispatcher + .send(DispatchCommand::Publish { + topic: format!( + "{}/registers/{}/{}", + topic_prefix, id, r.range.address + ), + payload: payload.clone(), + }) + .await + .unwrap(); + + if let Some(name) = &r.name { + dispatcher + .send(DispatchCommand::Publish { + topic: format!( + "{}/registers/{}/{}", + topic_prefix, id, name + ), + payload: payload, + }) + .await + .unwrap(); + } + } + } + } + }); + } } Err(err) => { dispatcher