1
0
Fork 0

wip: connector mostly working

refactor
Bo Jeanes 2022-09-08 10:16:44 +10:00
parent a742fd2986
commit cd3cc6f19c
12 changed files with 1375 additions and 1030 deletions

View File

@ -19,7 +19,7 @@ itertools = "0.10.3"
rumqttc = { version = "0.15.0", default-features = true, features = ["url"] } # https://github.com/bytebeamio/rumqtt/issues/446
rust_decimal = { version = "1.26.1", features = ["serde-arbitrary-precision", "serde-float", "serde_json", "maths"] }
serde = { version = "1.0.139", features = ["serde_derive"] }
serde_json = "1.0.82"
serde_json = { version = "1.0.82", features = ["raw_value"] }
serialport = { version = "4.2.0", optional = true, features = ["serde"] }
thiserror = "1.0.33"
tokio = { version = "1.20.0", features = ["rt", "rt-multi-thread", "time", "signal"] }

View File

@ -1,5 +1,6 @@
use clap::Parser;
use modbus_mqtt::{server, Result};
use rumqttc::MqttOptions;
use url::Url;
#[derive(Parser, Debug)]
@ -14,7 +15,8 @@ struct Cli {
env = "MQTT_URL",
// validator = "is_mqtt_url",
default_value = "mqtt://localhost:1883/modbus-mqtt",
value_hint = clap::ValueHint::Url
value_hint = clap::ValueHint::Url,
help = "Pass the topic prefix as the URL path"
)]
url: Url,
}
@ -22,24 +24,33 @@ struct Cli {
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt::init();
let mut args = Cli::parse();
let prefix = args
.url
let Cli { mut url } = Cli::parse();
let mut prefix = url
.path()
.trim_start_matches('/')
.split('/')
.next()
.unwrap_or(env!("CARGO_PKG_NAME"))
.trim_end_matches('/')
.to_owned();
// FIXME: if they pass "?client_id=foo" param, skip this
args.url
.query_pairs_mut()
.append_pair("client_id", env!("CARGO_PKG_NAME"))
.finish();
let options: MqttOptions = match url.clone().try_into() {
Ok(options) => options,
Err(rumqttc::OptionError::ClientId) => {
let url = url
.query_pairs_mut()
.append_pair("client_id", env!("CARGO_PKG_NAME"))
.finish()
.clone();
url.try_into()?
}
Err(other) => return Err(other.into()),
};
server::run(prefix, args.url.try_into()?, tokio::signal::ctrl_c()).await?;
if prefix.is_empty() {
prefix = options.client_id();
}
server::run(prefix, options, tokio::signal::ctrl_c()).await?;
Ok(())
}

View File

@ -3,6 +3,6 @@ use tokio::sync::mpsc::Sender;
use crate::{modbus::register::Register, mqtt};
/// Describes the register to Home Assistant
fn configure(register: Register, tx: Sender<mqtt::Message>) -> crate::Result<()> {
fn configure(_register: Register, _tx: Sender<mqtt::Message>) -> crate::Result<()> {
Ok(())
}

View File

@ -1,15 +1,10 @@
use rumqttc::{self, AsyncClient, Event, Incoming, LastWill, MqttOptions, Publish, QoS};
use serde::Serialize;
use serde_json::json;
use std::{collections::HashMap, time::Duration};
use tokio::{sync::mpsc, sync::oneshot, time::MissedTickBehavior};
use tokio_modbus::prelude::*;
use tracing::{debug, error, info};
use rumqttc::{self};
use tracing::error;
use thiserror::Error;
mod shutdown;
pub(crate) use shutdown::Shutdown;
pub mod homeassistant;
pub mod modbus;
@ -22,15 +17,33 @@ pub enum Error {
#[error(transparent)]
IOError(#[from] std::io::Error),
#[error("{0}")]
Other(std::borrow::Cow<'static, str>),
#[error(transparent)]
MQTTOptionError(#[from] rumqttc::OptionError),
#[error(transparent)]
MQTTClientError(#[from] rumqttc::ClientError),
#[error(transparent)]
MQTTConnectionError(#[from] rumqttc::ConnectionError),
#[error(transparent)]
InvalidSocketAddr(#[from] std::net::AddrParseError),
#[error(transparent)]
SerialError(#[from] tokio_serial::Error),
#[error("RecvError")]
RecvError,
#[error("SendError")]
SendError,
#[error("Unrecognised modbus protocol")]
UnrecognisedModbusProtocol,
#[error("{0}")]
Other(std::borrow::Cow<'static, str>),
#[error("Unknown")]
Unknown,
}
@ -48,391 +61,384 @@ impl From<&'static str> for Error {
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Serialize)]
#[serde(rename_all = "lowercase")]
enum MainStatus {
Running,
Stopped,
}
// #[derive(Debug)]
// enum DispatchCommand {
// Publish { topic: String, payload: Vec<u8> },
// }
// #[tracing::instrument(level = "debug")]
// async fn mqtt_dispatcher(
// mut options: MqttOptions,
// prefix: String,
// registry: mpsc::Sender<RegistryCommand>,
// mut rx: mpsc::Receiver<DispatchCommand>,
// ) {
// info!("Connecting to MQTT broker...");
#[derive(Debug)]
enum DispatchCommand {
Publish { topic: String, payload: Vec<u8> },
}
#[tracing::instrument(level = "debug")]
async fn mqtt_dispatcher(
mut options: MqttOptions,
prefix: String,
registry: mpsc::Sender<RegistryCommand>,
mut rx: mpsc::Receiver<DispatchCommand>,
) {
info!("Connecting to MQTT broker...");
// options.set_last_will(LastWill {
// topic: format!("{}/status", prefix),
// message: serde_json::to_vec(&json!({
// "status": MainStatus::Stopped,
// }))
// .unwrap()
// .into(),
// qos: QoS::AtMostOnce,
// retain: false,
// });
options.set_last_will(LastWill {
topic: format!("{}/status", prefix),
message: serde_json::to_vec(&json!({
"status": MainStatus::Stopped,
}))
.unwrap()
.into(),
qos: QoS::AtMostOnce,
retain: false,
});
// let (client, mut eventloop) = AsyncClient::new(options, 10);
let (client, mut eventloop) = AsyncClient::new(options, 10);
// client
// .publish(
// format!("{}/status", prefix),
// QoS::AtMostOnce,
// false,
// serde_json::to_vec(&json!({
// "status": MainStatus::Running,
// }))
// .unwrap(),
// )
// .await
// .unwrap();
client
.publish(
format!("{}/status", prefix),
QoS::AtMostOnce,
false,
serde_json::to_vec(&json!({
"status": MainStatus::Running,
}))
.unwrap(),
)
.await
.unwrap();
// client
// .subscribe(format!("{}/connect/#", prefix), QoS::AtMostOnce)
// .await
// .unwrap();
client
.subscribe(format!("{}/connect/#", prefix), QoS::AtMostOnce)
.await
.unwrap();
// let rx_loop_handler = {
// let client = client.clone();
// tokio::spawn(async move {
// info!("Start dispatcher rx loop");
// while let Some(command) = rx.recv().await {
// match command {
// DispatchCommand::Publish { topic, payload } => {
// client
// .publish(topic, QoS::AtMostOnce, false, payload)
// .await
// .unwrap();
// }
// }
// }
// })
// };
let rx_loop_handler = {
let client = client.clone();
tokio::spawn(async move {
info!("Start dispatcher rx loop");
while let Some(command) = rx.recv().await {
match command {
DispatchCommand::Publish { topic, payload } => {
client
.publish(topic, QoS::AtMostOnce, false, payload)
.await
.unwrap();
}
}
}
})
};
// while let Ok(event) = eventloop.poll().await {
// use Event::{Incoming as In, Outgoing as Out};
while let Ok(event) = eventloop.poll().await {
use Event::{Incoming as In, Outgoing as Out};
// match event {
// Out(_) => (),
// In(Incoming::ConnAck(_)) => info!("Connected to MQTT!"),
// In(Incoming::PingResp | Incoming::SubAck(_)) => (),
match event {
Out(_) => (),
In(Incoming::ConnAck(_)) => info!("Connected to MQTT!"),
In(Incoming::PingResp | Incoming::SubAck(_)) => (),
// In(Incoming::Publish(Publish { topic, payload, .. })) => {
// debug!("{} -> {:?}", &topic, &payload);
In(Incoming::Publish(Publish { topic, payload, .. })) => {
debug!("{} -> {:?}", &topic, &payload);
// match topic.split('/').collect::<Vec<&str>>()[..] {
// [p, "connect", conn_name] if p == prefix.as_str() => {
// registry
// .send(RegistryCommand::Connect {
// id: conn_name.to_string(),
// details: payload,
// })
// .await
// .unwrap();
// }
// _ => (),
// };
// }
// _ => {
// debug!("{:?}", event);
// }
// }
// }
match topic.split('/').collect::<Vec<&str>>()[..] {
[p, "connect", conn_name] if p == prefix.as_str() => {
registry
.send(RegistryCommand::Connect {
id: conn_name.to_string(),
details: payload,
})
.await
.unwrap();
}
_ => (),
};
}
_ => {
debug!("{:?}", event);
}
}
}
// rx_loop_handler.await.unwrap();
// }
rx_loop_handler.await.unwrap();
}
// type ConnectionId = String;
type ConnectionId = String;
// #[derive(Debug)]
// enum RegistryCommand {
// Connect {
// id: ConnectionId,
// details: bytes::Bytes,
// },
// Disconnect(ConnectionId),
// }
#[derive(Debug)]
enum RegistryCommand {
Connect {
id: ConnectionId,
details: bytes::Bytes,
},
Disconnect(ConnectionId),
}
// type RegistryDb = HashMap<ConnectionId, tokio::task::JoinHandle<()>>;
type RegistryDb = HashMap<ConnectionId, tokio::task::JoinHandle<()>>;
// #[tracing::instrument(level = "debug")]
// async fn connection_registry(
// prefix: String,
// dispatcher: mpsc::Sender<DispatchCommand>,
// mut rx: mpsc::Receiver<RegistryCommand>,
// ) {
// info!("Starting connection registry...");
// let mut db: RegistryDb = HashMap::new();
#[tracing::instrument(level = "debug")]
async fn connection_registry(
prefix: String,
dispatcher: mpsc::Sender<DispatchCommand>,
mut rx: mpsc::Receiver<RegistryCommand>,
) {
info!("Starting connection registry...");
let mut db: RegistryDb = HashMap::new();
// while let Some(command) = rx.recv().await {
// use RegistryCommand::*;
// match command {
// Disconnect(id) => {
// if let Some(handle) = db.remove(&id) {
// handle.abort();
// }
// }
// Connect { id, details } => {
// info!(id, payload = ?details, "Establishing connection");
// let prefix = prefix.clone();
// let dispatcher = dispatcher.clone();
while let Some(command) = rx.recv().await {
use RegistryCommand::*;
match command {
Disconnect(id) => {
if let Some(handle) = db.remove(&id) {
handle.abort();
}
}
Connect { id, details } => {
info!(id, payload = ?details, "Establishing connection");
let prefix = prefix.clone();
let dispatcher = dispatcher.clone();
// if let Some(handle) = db.remove(&id) {
// handle.abort();
// }
if let Some(handle) = db.remove(&id) {
handle.abort();
}
// db.insert(
// id.clone(),
// tokio::spawn(handle_connect(dispatcher, id, prefix, details)),
// );
// }
// _ => error!("unimplemented"),
// }
// }
// }
db.insert(
id.clone(),
tokio::spawn(handle_connect(dispatcher, id, prefix, details)),
);
}
_ => error!("unimplemented"),
}
}
}
// #[derive(Clone, Copy, Debug)]
// enum ModbusReadType {
// Input,
// Hold,
// }
#[derive(Clone, Copy, Debug)]
enum ModbusReadType {
Input,
Hold,
}
// #[derive(Debug)]
// enum ModbusCommand {
// Read(ModbusReadType, u16, u8, ModbusResponse),
// Write(u16, Vec<u16>, ModbusResponse),
// }
#[derive(Debug)]
enum ModbusCommand {
Read(ModbusReadType, u16, u8, ModbusResponse),
Write(u16, Vec<u16>, ModbusResponse),
}
// type ModbusResponse = oneshot::Sender<Result<Vec<u16>>>;
type ModbusResponse = oneshot::Sender<Result<Vec<u16>>>;
// #[tracing::instrument(level = "debug")]
// async fn handle_connect(
// dispatcher: mpsc::Sender<DispatchCommand>,
// id: ConnectionId,
// topic_prefix: String,
// payload: bytes::Bytes,
// ) {
// use modbus::config::*;
// use modbus::ConnectState;
// info!("Starting connection handler for {}", id);
// match serde_json::from_slice::<Connect>(&payload) {
// Ok(connect) => {
// let unit = connect.unit;
#[tracing::instrument(level = "debug")]
async fn handle_connect(
dispatcher: mpsc::Sender<DispatchCommand>,
id: ConnectionId,
topic_prefix: String,
payload: bytes::Bytes,
) {
use modbus::config::*;
use modbus::ConnectState;
info!("Starting connection handler for {}", id);
match serde_json::from_slice::<Connect>(&payload) {
Ok(connect) => {
let unit = connect.unit;
// let mut modbus: tokio_modbus::client::Context = match connect.settings {
// #[cfg(feature = "winet-s")]
// ModbusProto::SungrowWiNetS { ref host } => {
// tokio_modbus_winets::connect_slave(host, unit)
// .await
// .unwrap()
// }
// #[cfg(feature = "tcp")]
// ModbusProto::Tcp { ref host, port } => {
// let socket_addr = format!("{}:{}", host, port).parse().unwrap();
// tcp::connect_slave(socket_addr, unit).await.unwrap()
// }
// #[cfg(feature = "rtu")]
// ModbusProto::Rtu {
// ref tty,
// baud_rate,
// data_bits,
// stop_bits,
// flow_control,
// parity,
// } => {
// let builder = tokio_serial::new(tty, baud_rate)
// .data_bits(data_bits)
// .flow_control(flow_control)
// .parity(parity)
// .stop_bits(stop_bits);
// let port = tokio_serial::SerialStream::open(&builder).unwrap();
// rtu::connect_slave(port, unit).await.unwrap()
// }
// ModbusProto::Unknown => {
// error!("Unrecognised protocol");
// return;
// }
// };
// let status = modbus::ConnectStatus {
// connect: connect.clone(),
// status: ConnectState::Connected,
// };
// dispatcher
// .send(DispatchCommand::Publish {
// topic: format!("{}/status/{}", topic_prefix, id),
// payload: serde_json::to_vec(&status).unwrap(),
// })
// .await
// .unwrap();
let mut modbus: tokio_modbus::client::Context = match connect.settings {
#[cfg(feature = "winet-s")]
ModbusProto::SungrowWiNetS { ref host } => {
tokio_modbus_winets::connect_slave(host, unit)
.await
.unwrap()
}
#[cfg(feature = "tcp")]
ModbusProto::Tcp { ref host, port } => {
let socket_addr = format!("{}:{}", host, port).parse().unwrap();
tcp::connect_slave(socket_addr, unit).await.unwrap()
}
#[cfg(feature = "rtu")]
ModbusProto::Rtu {
ref tty,
baud_rate,
data_bits,
stop_bits,
flow_control,
parity,
} => {
let builder = tokio_serial::new(tty, baud_rate)
.data_bits(data_bits)
.flow_control(flow_control)
.parity(parity)
.stop_bits(stop_bits);
let port = tokio_serial::SerialStream::open(&builder).unwrap();
rtu::connect_slave(port, unit).await.unwrap()
}
ModbusProto::Unknown => {
error!("Unrecognised protocol");
return;
}
};
let status = modbus::ConnectStatus {
connect: connect.clone(),
status: ConnectState::Connected,
};
dispatcher
.send(DispatchCommand::Publish {
topic: format!("{}/status/{}", topic_prefix, id),
payload: serde_json::to_vec(&status).unwrap(),
})
.await
.unwrap();
// let (modbus_tx, mut modbus_rx) = mpsc::channel::<ModbusCommand>(32);
// tokio::spawn(async move {
// while let Some(command) = modbus_rx.recv().await {
// match command {
// ModbusCommand::Read(read_type, address, count, responder) => {
// let response = match read_type {
// ModbusReadType::Input => {
// modbus.read_input_registers(address, count as u16)
// }
// ModbusReadType::Hold => {
// modbus.read_holding_registers(address, count as u16)
// }
// };
let (modbus_tx, mut modbus_rx) = mpsc::channel::<ModbusCommand>(32);
tokio::spawn(async move {
while let Some(command) = modbus_rx.recv().await {
match command {
ModbusCommand::Read(read_type, address, count, responder) => {
let response = match read_type {
ModbusReadType::Input => {
modbus.read_input_registers(address, count as u16)
}
ModbusReadType::Hold => {
modbus.read_holding_registers(address, count as u16)
}
};
// responder.send(response.await.map_err(Into::into)).unwrap();
// }
// ModbusCommand::Write(address, data, responder) => {
// responder
// .send(
// modbus
// .read_write_multiple_registers(
// address,
// data.len() as u16,
// address,
// &data[..],
// )
// .await
// .map_err(Into::into),
// )
// .unwrap();
// }
// }
// }
// });
responder.send(response.await.map_err(Into::into)).unwrap();
}
ModbusCommand::Write(address, data, responder) => {
responder
.send(
modbus
.read_write_multiple_registers(
address,
data.len() as u16,
address,
&data[..],
)
.await
.map_err(Into::into),
)
.unwrap();
}
}
}
});
// use itertools::Itertools;
// for (duration, registers) in &connect.input.into_iter().group_by(|r| r.interval) {
// let registers_prefix = format!("{}/input/{}", topic_prefix, id);
use itertools::Itertools;
for (duration, registers) in &connect.input.into_iter().group_by(|r| r.interval) {
let registers_prefix = format!("{}/input/{}", topic_prefix, id);
// tokio::spawn(watch_registers(
// ModbusReadType::Input,
// connect.address_offset,
// duration,
// registers.collect(),
// modbus_tx.clone(),
// dispatcher.clone(),
// registers_prefix,
// ));
// }
// for (duration, registers) in &connect.hold.into_iter().group_by(|r| r.interval) {
// let registers_prefix = format!("{}/hold/{}", topic_prefix, id);
tokio::spawn(watch_registers(
ModbusReadType::Input,
connect.address_offset,
duration,
registers.collect(),
modbus_tx.clone(),
dispatcher.clone(),
registers_prefix,
));
}
for (duration, registers) in &connect.hold.into_iter().group_by(|r| r.interval) {
let registers_prefix = format!("{}/hold/{}", topic_prefix, id);
// tokio::spawn(watch_registers(
// ModbusReadType::Hold,
// connect.address_offset,
// duration,
// registers.collect(),
// modbus_tx.clone(),
// dispatcher.clone(),
// registers_prefix,
// ));
// }
// }
// Err(err) => {
// dispatcher
// .send(DispatchCommand::Publish {
// topic: format!("{}/status/{}", topic_prefix, id),
// payload: serde_json::to_vec(&json!({
// "status": ConnectState::Errored,
// "error": format!("Invalid config: {}", err),
// }))
// .unwrap(),
// })
// .await
// .unwrap();
// }
// }
// }
tokio::spawn(watch_registers(
ModbusReadType::Hold,
connect.address_offset,
duration,
registers.collect(),
modbus_tx.clone(),
dispatcher.clone(),
registers_prefix,
));
}
}
Err(err) => {
dispatcher
.send(DispatchCommand::Publish {
topic: format!("{}/status/{}", topic_prefix, id),
payload: serde_json::to_vec(&json!({
"status": ConnectState::Errored,
"error": format!("Invalid config: {}", err),
}))
.unwrap(),
})
.await
.unwrap();
}
}
}
// #[tracing::instrument(level = "debug")]
// async fn watch_registers(
// read_type: ModbusReadType,
// address_offset: i8,
// duration: Duration,
// registers: Vec<modbus::config::Register>,
// modbus: mpsc::Sender<ModbusCommand>,
// dispatcher: mpsc::Sender<DispatchCommand>,
// registers_prefix: String,
// ) -> ! {
// let mut interval = tokio::time::interval(duration);
// interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
#[tracing::instrument(level = "debug")]
async fn watch_registers(
read_type: ModbusReadType,
address_offset: i8,
duration: Duration,
registers: Vec<modbus::config::Register>,
modbus: mpsc::Sender<ModbusCommand>,
dispatcher: mpsc::Sender<DispatchCommand>,
registers_prefix: String,
) -> ! {
let mut interval = tokio::time::interval(duration);
interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
// loop {
// interval.tick().await;
// for r in registers.iter() {
// let address = if address_offset >= 0 {
// r.address.checked_add(address_offset as u16)
// } else {
// r.address.checked_sub(address_offset.unsigned_abs() as u16)
// };
// if let Some(address) = address {
// let size = r.parse.value_type.size();
// debug!(
// name = r.name.as_ref().unwrap_or(&"".to_string()),
// address,
// size,
// register_type = ?read_type,
// value_type = r.parse.value_type.type_name(),
// "Polling register",
// );
loop {
interval.tick().await;
for r in registers.iter() {
let address = if address_offset >= 0 {
r.address.checked_add(address_offset as u16)
} else {
r.address.checked_sub(address_offset.unsigned_abs() as u16)
};
if let Some(address) = address {
let size = r.parse.value_type.size();
debug!(
name = r.name.as_ref().unwrap_or(&"".to_string()),
address,
size,
register_type = ?read_type,
value_type = r.parse.value_type.type_name(),
"Polling register",
);
// let (tx, rx) = oneshot::channel();
let (tx, rx) = oneshot::channel();
// modbus
// .send(ModbusCommand::Read(read_type, address, size, tx))
// .await
// .unwrap();
modbus
.send(ModbusCommand::Read(read_type, address, size, tx))
.await
.unwrap();
// // FIXME: definitely getting errors here that need to be handled
// //
// // thread 'tokio-runtime-worker' panicked at 'called `Result::unwrap()` on an `Err` value: Error { kind: UnexpectedEof, message: "failed to fill whole buffer" }'
// // thread 'tokio-runtime-worker' panicked at 'called `Result::unwrap()` on an `Err` value: Custom { kind: InvalidData, error: "Invalid data length: 0" }'
// // thread 'tokio-runtime-worker' panicked at 'called `Result::unwrap()` on an `Err` value: Os { code: 36, kind: Uncategorized, message: "Operation now in progress" }'
// // thread 'tokio-runtime-worker' panicked at 'called `Result::unwrap()` on an `Err` value: Os { code: 35, kind: WouldBlock, message: "Resource temporarily unavailable" }
// //
// // Splitting out the two awaits so I can see if all of the above panics come from the same await or some from one vs the other:
// let response = rx.await.unwrap(); // await may have errorer on receiving
// let words = response.unwrap(); // received message is also a result which may be a (presumably Modbus?) error
// FIXME: definitely getting errors here that need to be handled
//
// thread 'tokio-runtime-worker' panicked at 'called `Result::unwrap()` on an `Err` value: Error { kind: UnexpectedEof, message: "failed to fill whole buffer" }'
// thread 'tokio-runtime-worker' panicked at 'called `Result::unwrap()` on an `Err` value: Custom { kind: InvalidData, error: "Invalid data length: 0" }'
// thread 'tokio-runtime-worker' panicked at 'called `Result::unwrap()` on an `Err` value: Os { code: 36, kind: Uncategorized, message: "Operation now in progress" }'
// thread 'tokio-runtime-worker' panicked at 'called `Result::unwrap()` on an `Err` value: Os { code: 35, kind: WouldBlock, message: "Resource temporarily unavailable" }
//
// Splitting out the two awaits so I can see if all of the above panics come from the same await or some from one vs the other:
let response = rx.await.unwrap(); // await may have errorer on receiving
let words = response.unwrap(); // received message is also a result which may be a (presumably Modbus?) error
// let swapped_words = r.apply_swaps(&words);
let swapped_words = r.apply_swaps(&words);
// let value = r.parse_words(&swapped_words);
let value = r.parse_words(&swapped_words);
// debug!(
// name = r.name.as_ref().unwrap_or(&"".to_string()),
// address,
// %value,
// raw = ?words,
// "Received value",
// );
debug!(
name = r.name.as_ref().unwrap_or(&"".to_string()),
address,
%value,
raw = ?words,
"Received value",
);
// let payload = serde_json::to_vec(&json!({ "value": value, "raw": words })).unwrap();
let payload = serde_json::to_vec(&json!({ "value": value, "raw": words })).unwrap();
// dispatcher
// .send(DispatchCommand::Publish {
// topic: format!("{}/{}", registers_prefix, r.address),
// payload: payload.clone(),
// })
// .await
// .unwrap();
dispatcher
.send(DispatchCommand::Publish {
topic: format!("{}/{}", registers_prefix, r.address),
payload: payload.clone(),
})
.await
.unwrap();
if let Some(name) = &r.name {
dispatcher
.send(DispatchCommand::Publish {
topic: format!("{}/{}", registers_prefix, name),
payload,
})
.await
.unwrap();
}
}
}
}
}
// if let Some(name) = &r.name {
// dispatcher
// .send(DispatchCommand::Publish {
// topic: format!("{}/{}", registers_prefix, name),
// payload,
// })
// .await
// .unwrap();
// }
// }
// }
// }
// }

View File

@ -1,518 +0,0 @@
use serde::{Deserialize, Serialize};
use std::time::Duration;
#[cfg(test)]
use serde_json::json;
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "proto", rename_all = "lowercase")]
pub enum ModbusProto {
#[cfg(feature = "tcp")]
Tcp {
host: String,
#[serde(default = "default_modbus_port")]
port: u16,
},
#[cfg(feature = "rtu")]
#[serde(rename_all = "lowercase")]
Rtu {
// tty: std::path::PathBuf,
tty: String,
baud_rate: u32,
#[serde(default = "default_modbus_data_bits")]
data_bits: tokio_serial::DataBits, // TODO: allow this to be represented as a number instead of string
#[serde(default = "default_modbus_stop_bits")]
stop_bits: tokio_serial::StopBits, // TODO: allow this to be represented as a number instead of string
#[serde(default = "default_modbus_flow_control")]
flow_control: tokio_serial::FlowControl,
#[serde(default = "default_modbus_parity")]
parity: tokio_serial::Parity,
},
#[cfg(feature = "winet-s")]
#[serde(rename = "winet-s")]
SungrowWiNetS { host: String },
// Predominantly for if the binary is compiled with no default features for some reason.
#[serde(other)]
Unknown,
}
fn default_modbus_port() -> u16 {
502
}
#[cfg(feature = "rtu")]
fn default_modbus_data_bits() -> tokio_serial::DataBits {
tokio_serial::DataBits::Eight
}
#[cfg(feature = "rtu")]
fn default_modbus_stop_bits() -> tokio_serial::StopBits {
tokio_serial::StopBits::One
}
#[cfg(feature = "rtu")]
fn default_modbus_flow_control() -> tokio_serial::FlowControl {
tokio_serial::FlowControl::None
}
#[cfg(feature = "rtu")]
fn default_modbus_parity() -> tokio_serial::Parity {
tokio_serial::Parity::None
}
#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase", default)]
pub struct RegisterNumericAdjustment {
pub scale: i8, // powers of 10 (0 = no adjustment, 1 = x10, -1 = /10)
pub offset: i8,
// precision: Option<u8>,
}
#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RegisterNumeric {
U8,
#[default]
U16,
U32,
U64,
#[serde(alias = "s8")]
I8,
#[serde(alias = "s16")]
I16,
#[serde(alias = "s32")]
I32,
#[serde(alias = "s64")]
I64,
F32,
F64,
}
impl RegisterNumeric {
// Modbus limits sequential reads to 125 apparently, so 8-bit should be fine - https://github.com/slowtec/tokio-modbus/issues/112#issuecomment-1095316069=
fn size(&self) -> u8 {
use RegisterNumeric::*;
// Each Modbus register holds 16-bits, so count is half what the byte count would be
match self {
U8 | I8 => 1,
U16 | I16 => 1,
U32 | I32 | F32 => 2,
U64 | I64 | F64 => 4,
}
}
fn type_name(&self) -> String {
format!("{:?}", *self).to_lowercase()
}
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename = "string")]
pub struct RegisterString {
length: u8,
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename = "array")]
pub struct RegisterArray {
count: u8,
#[serde(default)]
of: RegisterNumeric,
// Arrays are only of numeric types, so we can apply an adjustment here
#[serde(flatten, skip_serializing_if = "IsDefault::is_default")]
adjust: RegisterNumericAdjustment,
}
impl Default for RegisterArray {
fn default() -> Self {
Self {
count: 1,
of: Default::default(),
adjust: Default::default(),
}
}
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum RegisterValueType {
Numeric {
#[serde(rename = "type", default)]
of: RegisterNumeric,
#[serde(flatten, skip_serializing_if = "IsDefault::is_default")]
adjust: RegisterNumericAdjustment,
},
Array(RegisterArray),
String(RegisterString),
}
impl RegisterValueType {
pub fn type_name(&self) -> String {
match *self {
RegisterValueType::Numeric { ref of, .. } => of.type_name(),
RegisterValueType::Array(_) => "array".to_owned(),
RegisterValueType::String(_) => "string".to_owned(),
}
}
}
impl Default for RegisterValueType {
fn default() -> Self {
RegisterValueType::Numeric {
of: Default::default(),
adjust: Default::default(),
}
}
}
impl RegisterValueType {
// Modbus limits sequential reads to 125 apparently, so 8-bit should be fine - https://github.com/slowtec/tokio-modbus/issues/112#issuecomment-1095316069=
pub fn size(&self) -> u8 {
use RegisterValueType::*;
match self {
Numeric { of, .. } => of.size(),
String(RegisterString { length }) => *length,
Array(RegisterArray { of, count, .. }) => of.size() * count,
}
}
}
#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Swap(pub bool);
trait IsDefault {
fn is_default(&self) -> bool;
}
impl<T> IsDefault for T
where
T: Default + PartialEq,
{
fn is_default(&self) -> bool {
*self == Default::default()
}
}
#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct RegisterParse {
#[serde(default, skip_serializing_if = "IsDefault::is_default")]
pub swap_bytes: Swap,
#[serde(default, skip_serializing_if = "IsDefault::is_default")]
pub swap_words: Swap,
#[serde(flatten, skip_serializing_if = "IsDefault::is_default")]
pub value_type: RegisterValueType,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Register {
pub address: u16,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(flatten, default, skip_serializing_if = "IsDefault::is_default")]
pub parse: RegisterParse,
#[serde(
with = "humantime_serde",
default = "default_register_interval",
alias = "period",
alias = "duration"
)]
pub interval: Duration,
}
fn default_register_interval() -> Duration {
Duration::from_secs(60)
}
#[derive(Clone, Serialize, Deserialize)]
pub struct Connect {
#[serde(flatten)]
pub settings: ModbusProto,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub input: Vec<Register>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub hold: Vec<Register>,
#[serde(
alias = "slave",
default = "tokio_modbus::slave::Slave::broadcast",
with = "Unit"
)]
pub unit: crate::modbus::Unit,
#[serde(default = "default_address_offset")]
pub address_offset: i8,
}
#[derive(Serialize, Deserialize)]
#[serde(remote = "tokio_modbus::slave::Slave")]
struct Unit(crate::modbus::UnitId);
fn default_address_offset() -> i8 {
0
}
#[test]
fn parse_minimal_tcp_connect_config() {
let result = serde_json::from_value::<Connect>(json!({
"proto": "tcp",
"host": "1.1.1.1"
}));
let connect = result.unwrap();
assert!(matches!(
connect.settings,
ModbusProto::Tcp {
ref host,
port: 502
} if host == "1.1.1.1"
))
}
#[test]
fn parse_full_tcp_connect_config() {
let _ = serde_json::from_value::<Connect>(json!({
"proto": "tcp",
"host": "10.10.10.219",
"unit": 1,
"address_offset": -1,
"input": [
{
"address": 5017,
"type": "u32",
"name": "dc_power",
"swap_words": false,
"period": "3s"
},
{
"address": 5008,
"type": "s16",
"name": "internal_temperature",
"period": "1m"
},
{
"address": 13008,
"type": "s32",
"name": "load_power",
"swap_words": false,
"period": "3s"
},
{
"address": 13010,
"type": "s32",
"name": "export_power",
"swap_words": false,
"period": "3s"
},
{
"address": 13022,
"name": "battery_power",
"period": "3s"
},
{
"address": 13023,
"name": "battery_level",
"period": "1m"
},
{
"address": 13024,
"name": "battery_health",
"period": "10m"
}
],
"hold": [
{
"address": 13058,
"name": "max_soc",
"period": "90s"
},
{
"address": 13059,
"name": "min_soc",
"period": "90s"
}
]
}))
.unwrap();
}
#[test]
fn parse_minimal_rtu_connect_config() {
let result = serde_json::from_value::<Connect>(json!({
"proto": "rtu",
"tty": "/dev/ttyUSB0",
"baud_rate": 9600,
}));
let connect = result.unwrap();
use tokio_serial::*;
assert!(matches!(
connect.settings,
ModbusProto::Rtu {
ref tty,
baud_rate: 9600,
data_bits: DataBits::Eight,
stop_bits: StopBits::One,
flow_control: FlowControl::None,
parity: Parity::None,
..
} if tty == "/dev/ttyUSB0"
))
}
#[test]
fn parse_complete_rtu_connect_config() {
let result = serde_json::from_value::<Connect>(json!({
"proto": "rtu",
"tty": "/dev/ttyUSB0",
"baud_rate": 12800,
// TODO: make lowercase words work
"data_bits": "Seven", // TODO: make 7 work
"stop_bits": "Two", // TODO: make 2 work
"flow_control": "Software",
"parity": "Even",
}));
let connect = result.unwrap();
use tokio_serial::*;
assert!(matches!(
connect.settings,
ModbusProto::Rtu {
ref tty,
baud_rate: 12800,
data_bits: DataBits::Seven,
stop_bits: StopBits::Two,
flow_control: FlowControl::Software,
parity: Parity::Even,
..
} if tty == "/dev/ttyUSB0"
),);
}
#[test]
fn parse_empty_register_parser_defaults() {
let empty = serde_json::from_value::<RegisterParse>(json!({}));
assert!(matches!(
empty.unwrap(),
RegisterParse {
swap_bytes: Swap(false),
swap_words: Swap(false),
value_type: RegisterValueType::Numeric {
of: RegisterNumeric::U16,
adjust: RegisterNumericAdjustment {
scale: 0,
offset: 0,
}
}
}
));
}
#[test]
fn parse_register_parser_type() {
let result = serde_json::from_value::<RegisterParse>(json!({
"type": "s32"
}));
assert!(matches!(
result.unwrap().value_type,
RegisterValueType::Numeric {
of: RegisterNumeric::I32,
..
}
));
}
#[test]
fn parse_register_parser_array() {
let result = serde_json::from_value::<RegisterParse>(json!({
"type": "array",
"of": "s32",
"count": 10,
}));
let payload = result.unwrap();
// println!("{:?}", payload);
// println!("{}", serde_json::to_string_pretty(&payload).unwrap());
assert!(matches!(
payload.value_type,
RegisterValueType::Array(RegisterArray {
of: RegisterNumeric::I32,
count: 10,
..
})
));
}
#[test]
fn parse_register_parser_array_implicit_u16() {
let result = serde_json::from_value::<RegisterParse>(json!({
"type": "array",
"count": 10,
}));
let payload = result.unwrap();
// println!("{:?}", payload);
// println!("{}", serde_json::to_string_pretty(&payload).unwrap());
assert!(matches!(
payload.value_type,
RegisterValueType::Array(RegisterArray {
of: RegisterNumeric::U16,
count: 10,
..
})
));
}
#[test]
fn parse_register_parser_string() {
let result = serde_json::from_value::<RegisterParse>(json!({
"type": "string",
"length": 10,
}));
let payload = result.unwrap();
// println!("{:?}", payload);
// println!("{}", serde_json::to_string_pretty(&payload).unwrap());
assert!(matches!(
payload.value_type,
RegisterValueType::String(RegisterString { length: 10, .. })
));
}
#[test]
fn parse_register_parser_scale_etc() {
let result = serde_json::from_value::<RegisterParse>(json!({
"type": "s32",
"scale": -1,
"offset": 20,
}));
assert!(matches!(
result.unwrap().value_type,
RegisterValueType::Numeric {
of: RegisterNumeric::I32,
adjust: RegisterNumericAdjustment {
scale: -1,
offset: 20
}
}
));
}

View File

@ -1 +1,176 @@
pub struct Connection {}
use crate::modbus::{self};
use serde::Deserialize;
use std::convert::TryFrom;
use tokio::{select, sync::mpsc};
use tokio_modbus::client::{rtu, tcp, Context as ModbusClient};
use tracing::error;
use crate::{mqtt, shutdown::Shutdown};
// TODO make this into run() and have it spawn the task
pub(crate) async fn new(
config: Config,
mqtt: mqtt::Handle,
shutdown: Shutdown,
) -> crate::Result<Connection> {
let client = match config.settings {
#[cfg(feature = "winet-s")]
ModbusProto::SungrowWiNetS { ref host } => {
tokio_modbus_winets::connect_slave(host, config.unit).await?
}
#[cfg(feature = "tcp")]
ModbusProto::Tcp { ref host, port } => {
let socket_addr = format!("{}:{}", host, port).parse()?;
tcp::connect_slave(socket_addr, config.unit).await?
}
#[cfg(feature = "rtu")]
ModbusProto::Rtu {
ref tty,
baud_rate,
data_bits,
stop_bits,
flow_control,
parity,
} => {
let builder = tokio_serial::new(tty, baud_rate)
.data_bits(data_bits)
.flow_control(flow_control)
.parity(parity)
.stop_bits(stop_bits);
let port = tokio_serial::SerialStream::open(&builder)?;
rtu::connect_slave(port, config.unit).await?
}
ModbusProto::Unknown => {
error!("Unrecognised protocol");
return Err(crate::Error::UnrecognisedModbusProtocol);
}
};
let (tx, rx) = mpsc::channel(32);
Ok(Connection {
rx,
client,
mqtt,
shutdown,
})
}
pub struct Connection {
client: ModbusClient,
mqtt: mqtt::Handle,
shutdown: Shutdown,
rx: mpsc::Receiver<Message>,
}
enum Message {}
#[derive(Clone)]
pub struct Handler {
tx: mpsc::Sender<Message>,
}
impl Connection {
pub async fn run(mut self) -> crate::Result<()> {
select! {
_ = self.shutdown.recv() => {
return Ok(());
}
}
}
// pub fn handle(&self) -> Handle {}
}
#[derive(Debug, Deserialize)]
pub(crate) struct Config {
#[serde(flatten)]
pub settings: ModbusProto,
#[serde(
alias = "slave",
default = "tokio_modbus::slave::Slave::broadcast",
with = "Unit"
)]
pub unit: modbus::Unit,
#[serde(default)]
pub address_offset: i8,
}
#[derive(Deserialize)]
#[serde(remote = "tokio_modbus::slave::Slave")]
pub(crate) struct Unit(crate::modbus::UnitId);
#[derive(Clone, Debug, Deserialize)]
#[serde(tag = "proto", rename_all = "lowercase")]
pub(crate) enum ModbusProto {
#[cfg(feature = "tcp")]
Tcp {
host: String,
#[serde(default = "default_modbus_port")]
port: u16,
},
#[cfg(feature = "rtu")]
#[serde(rename_all = "lowercase")]
Rtu {
// tty: std::path::PathBuf,
tty: String,
baud_rate: u32,
#[serde(default = "default_modbus_data_bits")]
data_bits: tokio_serial::DataBits, // TODO: allow this to be represented as a number instead of string
#[serde(default = "default_modbus_stop_bits")]
stop_bits: tokio_serial::StopBits, // TODO: allow this to be represented as a number instead of string
#[serde(default = "default_modbus_flow_control")]
flow_control: tokio_serial::FlowControl,
#[serde(default = "default_modbus_parity")]
parity: tokio_serial::Parity,
},
#[cfg(feature = "winet-s")]
#[serde(rename = "winet-s")]
SungrowWiNetS { host: String },
// Predominantly for if the binary is compiled with no default features for some reason.
#[serde(other)]
Unknown,
}
pub(crate) fn default_modbus_port() -> u16 {
502
}
#[cfg(feature = "rtu")]
pub(crate) fn default_modbus_data_bits() -> tokio_serial::DataBits {
tokio_serial::DataBits::Eight
}
#[cfg(feature = "rtu")]
pub(crate) fn default_modbus_stop_bits() -> tokio_serial::StopBits {
tokio_serial::StopBits::One
}
#[cfg(feature = "rtu")]
pub(crate) fn default_modbus_flow_control() -> tokio_serial::FlowControl {
tokio_serial::FlowControl::None
}
#[cfg(feature = "rtu")]
pub(crate) fn default_modbus_parity() -> tokio_serial::Parity {
tokio_serial::Parity::None
}
impl TryFrom<Config> for Connection {
type Error = crate::Error;
fn try_from(_value: Config) -> Result<Self, Self::Error> {
todo!()
}
}

View File

@ -0,0 +1,154 @@
use crate::modbus::{connection, register};
use crate::mqtt::{Payload, Scopable};
use crate::{mqtt, shutdown::Shutdown};
use serde::Deserialize;
use serde_json::value::RawValue as RawJSON;
use tokio::select;
use tracing::{debug, error, info};
/// The topic filter under the prefix to look for connection configs
const TOPIC: &str = "+/connect";
/// Responsible for monitoring MQTT topic for connection configs
pub struct Connector {
mqtt: mqtt::Handle,
shutdown: Shutdown,
// connections: Vec<connection::Handle>,
}
pub(crate) fn new(mqtt: mqtt::Handle, shutdown: Shutdown) -> Connector {
Connector {
mqtt,
shutdown,
// connections: vec![],
}
}
impl Connector {
pub async fn run(&mut self) -> crate::Result<()> {
let mut new_connection = self.mqtt.subscribe(TOPIC).await?;
loop {
select! {
Some(Payload { bytes, topic }) = new_connection.recv() => {
// `unwrap()` is safe here because of the shape of valid topics and the fact that we are subcribed
// to a topic under a prefix.
let connection_id = topic.rsplit('/').nth_back(1).unwrap();
let mqtt = self.mqtt.scoped(connection_id);
debug!(?connection_id, ?bytes, ?topic, "Received connection config");
if let Err(error) = parse_and_connect(bytes, mqtt, self.shutdown.clone()).await {
error!(?connection_id, ?error, "Error creating connection");
}
},
_ = self.shutdown.recv() => {
info!("shutting down connector");
break;
},
}
}
Ok(())
}
}
async fn parse_and_connect(
bytes: bytes::Bytes,
mqtt: mqtt::Handle,
shutdown: Shutdown,
) -> crate::Result<()> {
match serde_json::from_slice(&bytes) {
Err(_) => mqtt.publish("state", "invalid").await?,
Ok(Config {
connection:
connection::Config {
settings: connection::ModbusProto::Unknown,
..
},
..
}) => mqtt.publish("state", "unknown_proto").await?,
Ok(config) => {
debug!(?config);
connect(config, mqtt, shutdown).await?;
}
}
Ok(())
}
async fn connect(config: Config<'_>, mqtt: mqtt::Handle, shutdown: Shutdown) -> crate::Result<()> {
if shutdown.is_shutdown() {
return Ok(());
}
let Config {
connection: settings,
input,
holding,
} = config;
mqtt.publish("state", "connecting").await?;
// connection isn't able to be Send, so we have to create connection inside its run task and find out if instantiation failed by connection on a channel.
{
let mqtt = mqtt.clone();
let (tx, rx) = tokio::sync::oneshot::channel::<crate::Result<()>>();
tokio::spawn(async move {
match connection::new(settings, mqtt.clone(), shutdown).await {
Ok(connection) => {
if let Err(e) = mqtt.publish("state", "connected").await {
tx.send(Err(e))
} else {
tx.send(Ok(()))
}
.expect("unexpected closed receiver");
if let Err(error) = connection.run().await {
error!(?error, "Modbus connection quit unexpectedly");
}
}
Err(e) => {
tx.send(Err(e)).expect("unexpected closed receiver");
}
}
});
rx.await.map_err(|_| crate::Error::RecvError)??;
}
for reg in input {
let mqtt = mqtt.scoped("input");
if let Ok(r) = serde_json::from_slice::<register::AddressedRegister>(reg.get().as_bytes()) {
debug!(?r);
let bytes: bytes::Bytes = reg.get().as_bytes().to_owned().into();
mqtt.publish(r.address.to_string(), bytes).await?;
}
}
for reg in holding {
let mqtt = mqtt.scoped("holding");
if let Ok(r) = serde_json::from_slice::<register::AddressedRegister>(reg.get().as_bytes()) {
debug!(?r);
let bytes: bytes::Bytes = reg.get().as_bytes().to_owned().into();
mqtt.publish(r.address.to_string(), bytes).await?;
}
}
Ok(())
}
/// Wrapper around `modbus::connection::Config` that can include some registers inline, which the connector will
/// re-publish to the appropriate topic once the connection is established.
#[derive(Debug, Deserialize)]
struct Config<'a> {
#[serde(flatten)]
connection: connection::Config,
// Allow registers to be defined inline, but capture them as raw JSON so that if they have incorrect schema, we can
// still establish the Modbus connection. Valid registers will be re-emitted as individual register configs to MQTT,
// to be picked up by the connection.
#[serde(default, borrow)]
pub input: Vec<&'a RawJSON>,
#[serde(alias = "hold", default, borrow)]
pub holding: Vec<&'a RawJSON>,
}

View File

@ -1,10 +1,10 @@
use rust_decimal::{prelude::FromPrimitive, Decimal};
use serde::Serialize;
use self::config::{Register, RegisterValueType};
use self::register::{Register, RegisterValueType};
pub mod config;
pub mod connection;
pub mod connector;
pub mod register;
pub use connection::Connection;
@ -18,20 +18,20 @@ pub enum ConnectState {
Errored,
}
#[derive(Serialize)]
pub struct ConnectStatus {
#[serde(flatten)]
pub connect: config::Connect,
pub status: ConnectState,
}
// #[derive(Serialize)]
// pub struct ConnectStatus {
// #[serde(flatten)]
// pub connect: config::Connect,
// pub status: ConnectState,
// }
pub type UnitId = tokio_modbus::prelude::SlaveId;
pub type Unit = tokio_modbus::prelude::Slave;
impl RegisterValueType {
pub fn parse_words(&self, words: &[u16]) -> serde_json::Value {
use self::config::RegisterValueType as T;
use self::config::{RegisterArray, RegisterNumeric as N, RegisterString};
use self::register::RegisterValueType as T;
use self::register::{RegisterArray, RegisterNumeric as N, RegisterString};
use serde_json::json;
let bytes: Vec<u8> = words.iter().flat_map(|v| v.to_ne_bytes()).collect();
@ -132,7 +132,7 @@ impl Register {
use pretty_assertions::assert_eq;
#[test]
fn test_parse_1() {
use self::config::{RegisterParse, Swap};
use self::register::{RegisterParse, Swap};
use serde_json::json;
let reg = Register {
@ -143,8 +143,8 @@ fn test_parse_1() {
swap_bytes: Swap(false),
swap_words: Swap(false),
value_type: RegisterValueType::Numeric {
of: config::RegisterNumeric::I32,
adjust: config::RegisterNumericAdjustment {
of: register::RegisterNumeric::I32,
adjust: register::RegisterNumericAdjustment {
scale: 0,
offset: 0,
},

View File

@ -1 +1,428 @@
pub struct Register {}
use serde::{Deserialize, Serialize};
use std::time::Duration;
#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase", default)]
pub struct RegisterNumericAdjustment {
pub scale: i8, // powers of 10 (0 = no adjustment, 1 = x10, -1 = /10)
pub offset: i8,
// precision: Option<u8>,
}
#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RegisterNumeric {
U8,
#[default]
U16,
U32,
U64,
#[serde(alias = "s8")]
I8,
#[serde(alias = "s16")]
I16,
#[serde(alias = "s32")]
I32,
#[serde(alias = "s64")]
I64,
F32,
F64,
}
impl RegisterNumeric {
// Modbus limits sequential reads to 125 apparently, so 8-bit should be fine - https://github.com/slowtec/tokio-modbus/issues/112#issuecomment-1095316069=
fn size(&self) -> u8 {
use RegisterNumeric::*;
// Each Modbus register holds 16-bits, so count is half what the byte count would be
match self {
U8 | I8 => 1,
U16 | I16 => 1,
U32 | I32 | F32 => 2,
U64 | I64 | F64 => 4,
}
}
fn type_name(&self) -> String {
format!("{:?}", *self).to_lowercase()
}
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename = "string")]
pub struct RegisterString {
length: u8,
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename = "array")]
pub struct RegisterArray {
count: u8,
#[serde(default)]
of: RegisterNumeric,
// Arrays are only of numeric types, so we can apply an adjustment here
#[serde(flatten, skip_serializing_if = "IsDefault::is_default")]
adjust: RegisterNumericAdjustment,
}
impl Default for RegisterArray {
fn default() -> Self {
Self {
count: 1,
of: Default::default(),
adjust: Default::default(),
}
}
}
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum RegisterValueType {
Numeric {
#[serde(rename = "type", default)]
of: RegisterNumeric,
#[serde(flatten, skip_serializing_if = "IsDefault::is_default")]
adjust: RegisterNumericAdjustment,
},
Array(RegisterArray),
String(RegisterString),
}
impl RegisterValueType {
pub fn type_name(&self) -> String {
match *self {
RegisterValueType::Numeric { ref of, .. } => of.type_name(),
RegisterValueType::Array(_) => "array".to_owned(),
RegisterValueType::String(_) => "string".to_owned(),
}
}
}
impl Default for RegisterValueType {
fn default() -> Self {
RegisterValueType::Numeric {
of: Default::default(),
adjust: Default::default(),
}
}
}
impl RegisterValueType {
// Modbus limits sequential reads to 125 apparently, so 8-bit should be fine - https://github.com/slowtec/tokio-modbus/issues/112#issuecomment-1095316069=
pub fn size(&self) -> u8 {
use RegisterValueType::*;
match self {
Numeric { of, .. } => of.size(),
String(RegisterString { length }) => *length,
Array(RegisterArray { of, count, .. }) => of.size() * count,
}
}
}
#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Swap(pub bool);
trait IsDefault {
fn is_default(&self) -> bool;
}
impl<T> IsDefault for T
where
T: Default + PartialEq,
{
fn is_default(&self) -> bool {
*self == Default::default()
}
}
#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct RegisterParse {
#[serde(default, skip_serializing_if = "IsDefault::is_default")]
pub swap_bytes: Swap,
#[serde(default, skip_serializing_if = "IsDefault::is_default")]
pub swap_words: Swap,
#[serde(flatten, skip_serializing_if = "IsDefault::is_default")]
pub value_type: RegisterValueType,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Register {
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(flatten, default, skip_serializing_if = "IsDefault::is_default")]
pub parse: RegisterParse,
#[serde(
with = "humantime_serde",
default = "default_register_interval",
alias = "period",
alias = "duration"
)]
pub interval: Duration,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AddressedRegister {
pub address: u16,
#[serde(flatten)]
pub register: Register,
}
fn default_register_interval() -> Duration {
Duration::from_secs(60)
}
// #[test]
// fn parse_minimal_tcp_connect_config() {
// let result = serde_json::from_value::<Connect>(json!({
// "proto": "tcp",
// "host": "1.1.1.1"
// }));
// let connect = result.unwrap();
// assert!(matches!(
// connect.settings,
// ModbusProto::Tcp {
// ref host,
// port: 502
// } if host == "1.1.1.1"
// ))
// }
// #[test]
// fn parse_full_tcp_connect_config() {
// let _ = serde_json::from_value::<Connect>(json!({
// "proto": "tcp",
// "host": "10.10.10.219",
// "unit": 1,
// "address_offset": -1,
// "input": [
// {
// "address": 5017,
// "type": "u32",
// "name": "dc_power",
// "swap_words": false,
// "period": "3s"
// },
// {
// "address": 5008,
// "type": "s16",
// "name": "internal_temperature",
// "period": "1m"
// },
// {
// "address": 13008,
// "type": "s32",
// "name": "load_power",
// "swap_words": false,
// "period": "3s"
// },
// {
// "address": 13010,
// "type": "s32",
// "name": "export_power",
// "swap_words": false,
// "period": "3s"
// },
// {
// "address": 13022,
// "name": "battery_power",
// "period": "3s"
// },
// {
// "address": 13023,
// "name": "battery_level",
// "period": "1m"
// },
// {
// "address": 13024,
// "name": "battery_health",
// "period": "10m"
// }
// ],
// "hold": [
// {
// "address": 13058,
// "name": "max_soc",
// "period": "90s"
// },
// {
// "address": 13059,
// "name": "min_soc",
// "period": "90s"
// }
// ]
// }))
// .unwrap();
// }
// #[test]
// fn parse_minimal_rtu_connect_config() {
// let result = serde_json::from_value::<Connect>(json!({
// "proto": "rtu",
// "tty": "/dev/ttyUSB0",
// "baud_rate": 9600,
// }));
// let connect = result.unwrap();
// use tokio_serial::*;
// assert!(matches!(
// connect.settings,
// ModbusProto::Rtu {
// ref tty,
// baud_rate: 9600,
// data_bits: DataBits::Eight,
// stop_bits: StopBits::One,
// flow_control: FlowControl::None,
// parity: Parity::None,
// ..
// } if tty == "/dev/ttyUSB0"
// ))
// }
// #[test]
// fn parse_complete_rtu_connect_config() {
// let result = serde_json::from_value::<Connect>(json!({
// "proto": "rtu",
// "tty": "/dev/ttyUSB0",
// "baud_rate": 12800,
// // TODO: make lowercase words work
// "data_bits": "Seven", // TODO: make 7 work
// "stop_bits": "Two", // TODO: make 2 work
// "flow_control": "Software",
// "parity": "Even",
// }));
// let connect = result.unwrap();
// use tokio_serial::*;
// assert!(matches!(
// connect.settings,
// ModbusProto::Rtu {
// ref tty,
// baud_rate: 12800,
// data_bits: DataBits::Seven,
// stop_bits: StopBits::Two,
// flow_control: FlowControl::Software,
// parity: Parity::Even,
// ..
// } if tty == "/dev/ttyUSB0"
// ),);
// }
// #[test]
// fn parse_empty_register_parser_defaults() {
// let empty = serde_json::from_value::<RegisterParse>(json!({}));
// assert!(matches!(
// empty.unwrap(),
// RegisterParse {
// swap_bytes: Swap(false),
// swap_words: Swap(false),
// value_type: RegisterValueType::Numeric {
// of: RegisterNumeric::U16,
// adjust: RegisterNumericAdjustment {
// scale: 0,
// offset: 0,
// }
// }
// }
// ));
// }
// #[test]
// fn parse_register_parser_type() {
// let result = serde_json::from_value::<RegisterParse>(json!({
// "type": "s32"
// }));
// assert!(matches!(
// result.unwrap().value_type,
// RegisterValueType::Numeric {
// of: RegisterNumeric::I32,
// ..
// }
// ));
// }
// #[test]
// fn parse_register_parser_array() {
// let result = serde_json::from_value::<RegisterParse>(json!({
// "type": "array",
// "of": "s32",
// "count": 10,
// }));
// let payload = result.unwrap();
// // println!("{:?}", payload);
// // println!("{}", serde_json::to_string_pretty(&payload).unwrap());
// assert!(matches!(
// payload.value_type,
// RegisterValueType::Array(RegisterArray {
// of: RegisterNumeric::I32,
// count: 10,
// ..
// })
// ));
// }
// #[test]
// fn parse_register_parser_array_implicit_u16() {
// let result = serde_json::from_value::<RegisterParse>(json!({
// "type": "array",
// "count": 10,
// }));
// let payload = result.unwrap();
// // println!("{:?}", payload);
// // println!("{}", serde_json::to_string_pretty(&payload).unwrap());
// assert!(matches!(
// payload.value_type,
// RegisterValueType::Array(RegisterArray {
// of: RegisterNumeric::U16,
// count: 10,
// ..
// })
// ));
// }
// #[test]
// fn parse_register_parser_string() {
// let result = serde_json::from_value::<RegisterParse>(json!({
// "type": "string",
// "length": 10,
// }));
// let payload = result.unwrap();
// // println!("{:?}", payload);
// // println!("{}", serde_json::to_string_pretty(&payload).unwrap());
// assert!(matches!(
// payload.value_type,
// RegisterValueType::String(RegisterString { length: 10, .. })
// ));
// }
// #[test]
// fn parse_register_parser_scale_etc() {
// let result = serde_json::from_value::<RegisterParse>(json!({
// "type": "s32",
// "scale": -1,
// "offset": 20,
// }));
// assert!(matches!(
// result.unwrap().value_type,
// RegisterValueType::Numeric {
// of: RegisterNumeric::I32,
// adjust: RegisterNumericAdjustment {
// scale: -1,
// offset: 20
// }
// }
// ));
// }

View File

@ -7,15 +7,21 @@ use rumqttc::{
};
use tokio::{
select,
sync::mpsc::{channel, Receiver, Sender},
sync::mpsc::{self, channel, Receiver, Sender},
};
use tracing::{debug, warn};
use tracing::{debug, info, warn};
use crate::shutdown::Shutdown;
#[derive(Debug)]
pub struct Payload {
pub bytes: Bytes,
pub topic: String,
}
#[derive(Debug, Clone)]
pub enum Message {
Subscribe(Subscribe, Sender<Bytes>),
Subscribe(Subscribe, Sender<Payload>),
Publish(Publish),
Shutdown,
}
@ -37,7 +43,7 @@ pub(crate) async fn new(options: MqttOptions, shutdown: Shutdown) -> Connection
// Maintain internal subscriptions as well as MQTT subscriptions. Relay all received messages on MQTT subscribed topics
// to internal components who have a matching topic. Unsubscribe topics when no one is listening anymore.
pub(crate) struct Connection {
subscriptions: HashMap<String, Vec<Sender<Bytes>>>,
subscriptions: HashMap<String, Vec<Sender<Payload>>>,
tx: Sender<Message>,
rx: Receiver<Message>,
client: AsyncClient,
@ -50,10 +56,7 @@ impl Connection {
loop {
select! {
event = self.event_loop.poll() => {
match event {
Ok(event) => self.handle_event(event).await?,
_ => todo!()
}
self.handle_event(event?).await?
}
request = self.rx.recv() => {
match request {
@ -62,41 +65,33 @@ impl Connection {
Some(req) => self.handle_request(req).await?,
}
}
_ = self.shutdown.recv() => return Ok(())
_ = self.shutdown.recv() => {
info!("MQTT connection shutting down");
break;
}
}
}
Ok(())
}
/// Create a handle for interacting with the MQTT server such that a pre-provided prefix is transparently added to
/// all relevant commands which use a topic.
pub fn prefixed_handle<S: Into<String> + Send>(
&self,
prefix: S,
) -> crate::Result<Sender<Message>> {
pub fn prefixed_handle<S: Into<String> + Send>(&self, prefix: S) -> crate::Result<Handle> {
let prefix = prefix.into();
if !valid_topic(&prefix) {
return Err("Prefix is not a valid topic".into());
}
let inner_tx = self.handle();
let (wrapper_tx, mut wrapper_rx) = channel::<Message>(8);
let prefix: String = prefix.into();
tokio::spawn(async move {
while let Some(msg) = wrapper_rx.recv().await {
if inner_tx.send(msg.prefixed(prefix.clone())).await.is_err() {
break;
}
}
});
Ok(wrapper_tx)
Ok(self.handle().scoped(prefix))
}
pub fn handle(&self) -> Sender<Message> {
self.tx.clone()
pub fn handle(&self) -> Handle {
Handle {
prefix: None,
tx: self.tx.clone(),
}
}
async fn handle_event(&mut self, event: Event) -> crate::Result<()> {
@ -116,7 +111,7 @@ impl Connection {
}
#[tracing::instrument(level = "debug", skip(self), fields(subscriptions = ?self.subscriptions.keys()))]
async fn handle_data(&mut self, topic: String, payload: Bytes) -> crate::Result<()> {
async fn handle_data(&mut self, topic: String, bytes: Bytes) -> crate::Result<()> {
let mut targets = vec![];
// Remove subscriptions whose channels are closed, adding matching channels to the `targets` vec.
@ -138,7 +133,14 @@ impl Connection {
});
for target in targets {
if target.send(payload.clone()).await.is_err() {
if target
.send(Payload {
topic: topic.clone(),
bytes: bytes.clone(),
})
.await
.is_err()
{
// These will be removed above next time a matching payload is removed
}
}
@ -179,47 +181,122 @@ impl Connection {
}
}
trait Prefixable {
fn prefixed<S: Into<String>>(self, prefix: S) -> Self;
#[derive(Clone)]
pub struct Handle {
prefix: Option<String>,
tx: Sender<Message>,
}
impl Prefixable for Message {
fn prefixed<S: Into<String>>(self, prefix: S) -> Self {
impl Handle {
pub async fn subscribe<S: Into<String>>(&self, topic: S) -> crate::Result<Receiver<Payload>> {
let (tx_bytes, rx) = mpsc::channel(8);
let mut msg =
Message::Subscribe(Subscribe::new(topic, rumqttc::QoS::AtLeastOnce), tx_bytes);
if let Some(prefix) = &self.prefix {
msg = msg.scoped(prefix.to_owned());
}
self.tx
.send(msg)
.await
.map_err(|_| crate::Error::SendError)?;
Ok(rx)
}
pub async fn publish<S: Into<String>, B: Into<Bytes>>(
&self,
topic: S,
payload: B,
) -> crate::Result<()> {
let mut msg = Message::Publish(Publish::new(
topic,
rumqttc::QoS::AtLeastOnce,
payload.into(),
));
if let Some(prefix) = &self.prefix {
msg = msg.scoped(prefix.to_owned());
}
self.tx
.send(msg)
.await
.map_err(|_| crate::Error::SendError)?;
Ok(())
}
}
pub(crate) trait Scopable {
fn scoped<S: Into<String>>(&self, prefix: S) -> Self;
}
// FIXME: this doesn't actually _prefix_ it _appends_ to the existing prefix, so there's probably a better name for this
// trait, like: Scopable
impl Scopable for Handle {
fn scoped<S: Into<String>>(&self, prefix: S) -> Self {
match self {
Message::Subscribe(sub, bytes) => Message::Subscribe(sub.prefixed(prefix), bytes),
Message::Publish(publish) => Message::Publish(publish.prefixed(prefix)),
other => other,
Self { prefix: None, tx } => Self {
prefix: Some(prefix.into()),
tx: tx.clone(),
},
Self {
prefix: Some(existing),
tx,
} => Self {
prefix: Some(format!("{}/{}", existing, prefix.into())),
tx: tx.clone(),
},
}
}
}
impl Prefixable for Subscribe {
fn prefixed<S: Into<String>>(mut self, prefix: S) -> Self {
impl Scopable for Message {
fn scoped<S: Into<String>>(&self, prefix: S) -> Self {
match self {
Message::Subscribe(sub, bytes) => Message::Subscribe(sub.scoped(prefix), bytes.clone()),
Message::Publish(publish) => Message::Publish(publish.scoped(prefix)),
other => (*other).clone(),
}
}
}
impl Scopable for Subscribe {
fn scoped<S: Into<String>>(&self, prefix: S) -> Self {
let prefix: String = prefix.into();
Self {
pkid: self.pkid,
filters: self
.filters
.iter_mut()
.map(|f| f.clone().prefixed(prefix.clone()))
.iter()
.map(|f| f.clone().scoped(prefix.clone()))
.collect(),
}
}
}
impl Prefixable for Publish {
fn prefixed<S: Into<String>>(self, prefix: S) -> Self {
impl Scopable for Publish {
fn scoped<S: Into<String>>(&self, prefix: S) -> Self {
let mut prefixed = self.clone();
prefixed.topic = format!("{}/{}", prefix.into(), &self.topic);
prefixed
}
}
impl Prefixable for SubscribeFilter {
fn prefixed<S: Into<String>>(self, prefix: S) -> Self {
impl Scopable for SubscribeFilter {
fn scoped<S: Into<String>>(&self, prefix: S) -> Self {
SubscribeFilter {
path: format!("{}/{}", prefix.into(), &self.path),
qos: self.qos,
}
}
}
impl From<Payload> for Bytes {
fn from(payload: Payload) -> Self {
payload.bytes
}
}
impl std::ops::Deref for Payload {
type Target = Bytes;
fn deref(&self) -> &Self::Target {
&self.bytes
}
}

View File

@ -1,90 +1,69 @@
use crate::mqtt;
use crate::{modbus, mqtt};
use rumqttc::MqttOptions;
use std::future::Future;
use tokio::sync::broadcast;
use tracing::{debug, error, info};
use std::{future::Future, time::Duration};
use tokio::{
sync::{broadcast, mpsc},
time::timeout,
};
use tracing::{error, info};
pub struct Server {
notify_shutdown: broadcast::Sender<()>,
mqtt_connection: mqtt::Connection,
}
pub async fn run<P: Into<String>>(
pub async fn run<P: Into<String> + Send>(
prefix: P,
mqtt_options: MqttOptions,
mut mqtt_options: MqttOptions,
shutdown: impl Future,
) -> crate::Result<()> {
let prefix = prefix.into();
let (notify_shutdown, _) = broadcast::channel(1);
let mqtt_connection = mqtt::new(mqtt_options, notify_shutdown.subscribe().into()).await;
let (shutdown_complete_tx, mut shutdown_complete_rx) = mpsc::channel(1);
let mut server = Server {
notify_shutdown,
mqtt_connection,
};
// TODO: make sure mqtt connection is last thing to shutdown, so other components can send final messages.
mqtt_options.set_last_will(rumqttc::LastWill {
topic: prefix.clone(),
message: "offline".into(),
qos: rumqttc::QoS::AtMostOnce,
retain: false,
});
let mut mqtt_connection = mqtt::new(
mqtt_options,
(notify_shutdown.subscribe(), shutdown_complete_tx.clone()).into(),
)
.await;
mqtt_connection
.handle()
.publish(prefix.clone(), "online")
.await?;
let mqtt = mqtt_connection.prefixed_handle(prefix)?;
let mut ret = Ok(());
let mut connector = modbus::connector::new(
mqtt.clone(),
(notify_shutdown.subscribe(), shutdown_complete_tx.clone()).into(),
);
tokio::select! {
res = server.run() => {
if let Err(err) = res {
error!(cause = %err, "server error");
ret = Err(err)
} else {
info!("server finished running")
}
tokio::spawn(async move {
if let Err(err) = mqtt_connection.run().await {
error!(cause = %err, "MQTT connection error");
}
});
_ = shutdown => {
info!("shutting down");
tokio::spawn(async move {
if let Err(err) = connector.run().await {
error!(cause = %err, "Modbus connector error");
}
}
let Server {
notify_shutdown, ..
} = server;
});
shutdown.await;
drop(notify_shutdown);
drop(shutdown_complete_tx);
ret
}
impl Server {
async fn run(&mut self) -> crate::Result<()> {
info!("Starting up");
let tx = self.mqtt_connection.prefixed_handle("hello")?;
{
let tx = tx.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
loop {
interval.tick().await;
tx.send(mqtt::Message::Publish(rumqttc::Publish::new(
"foo/bar/baz",
rumqttc::QoS::AtLeastOnce,
"hello",
)))
.await
.unwrap();
}
});
}
tokio::spawn(async move {
let (tx_bytes, mut rx) = tokio::sync::mpsc::channel(32);
tx.send(mqtt::Message::Subscribe(
rumqttc::Subscribe::new("foo/+/baz", rumqttc::QoS::AtLeastOnce),
tx_bytes,
))
.await
.unwrap();
while let Some(bytes) = rx.recv().await {
debug!(?bytes, "received");
}
});
self.mqtt_connection.run().await
}
timeout(Duration::from_secs(5), shutdown_complete_rx.recv())
.await
.map_err(|_| {
crate::Error::Other("Shutdown didn't complete within 5 seconds; aborting".into())
})?;
info!("Shutdown.");
Ok(())
}

View File

@ -1,7 +1,7 @@
//! **Note**: this is a barely modified copy of the code which appears in mini-redis
type Notify = tokio::sync::broadcast::Receiver<()>;
type Guard = tokio::sync::mpsc::Sender<()>;
/// Listens for the server shutdown signal.
///
/// Shutdown is signalled using a `broadcast::Receiver`. Only a single value is
@ -19,6 +19,20 @@ pub(crate) struct Shutdown {
/// The receive half of the channel used to listen for shutdown.
notify: Notify,
/// Optional guard as a sender so that when the `Shutdown` struct is dropped, the other side of the channel is
/// closed.
guard: Option<Guard>,
}
impl Clone for Shutdown {
fn clone(&self) -> Self {
Self {
shutdown: self.shutdown,
notify: self.notify.resubscribe(),
guard: self.guard.clone(),
}
}
}
impl Shutdown {
@ -27,6 +41,15 @@ impl Shutdown {
Shutdown {
shutdown: false,
notify,
guard: None,
}
}
/// Create a new `Shutdown` backed by the given `broadcast::Receiver` with a given guard.
pub(crate) fn with_guard(notify: Notify, guard: Guard) -> Shutdown {
Shutdown {
shutdown: false,
notify,
guard: Some(guard),
}
}
@ -56,3 +79,14 @@ impl From<Notify> for Shutdown {
Self::new(notify)
}
}
impl From<(Notify, Guard)> for Shutdown {
fn from((notify, guard): (Notify, Guard)) -> Self {
Self::with_guard(notify, guard)
}
}
impl From<(Guard, Notify)> for Shutdown {
fn from((guard, notify): (Guard, Notify)) -> Self {
Self::with_guard(notify, guard)
}
}