diff --git a/Cargo.lock b/Cargo.lock index 3fa9f73..385c40c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -67,6 +67,34 @@ dependencies = [ "syn", ] +[[package]] +name = "async-tungstenite" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5682ea0913e5c20780fe5785abacb85a411e7437bf52a1bedb93ddb3972cb8dd" +dependencies = [ + "futures-io", + "futures-util", + "log", + "pin-project-lite", + "rustls-native-certs", + "tokio", + "tokio-rustls", + "tungstenite 0.16.0", +] + +[[package]] +name = "async_io_stream" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d7b9decdf35d8908a7e3ef02f64c5e9b1695e230154c0e8de3969142d9b94c" +dependencies = [ + "futures", + "pharos", + "rustc_version", + "tokio", +] + [[package]] name = "atty" version = "0.2.14" @@ -106,6 +134,15 @@ dependencies = [ "syn", ] +[[package]] +name = "block-buffer" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" +dependencies = [ + "generic-array", +] + [[package]] name = "block-buffer" version = "0.10.2" @@ -247,13 +284,22 @@ version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" +[[package]] +name = "digest" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" +dependencies = [ + "generic-array", +] + [[package]] name = "digest" version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2fb860ca6fafa5552fb6d0e816a69c8e49f0908bf524e30a90d97c85892d506" dependencies = [ - "block-buffer", + "block-buffer 0.10.2", "crypto-common", ] @@ -739,6 +785,7 @@ dependencies = [ "tokio_modbus-winets", "tracing", "tracing-subscriber", + "url", ] [[package]] @@ -809,6 +856,12 @@ version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "074864da206b4973b84eb91683020dbefd6a8c3f0f38e054d93954e891935e4e" +[[package]] +name = "opaque-debug" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" + [[package]] name = "openssl-probe" version = "0.1.5" @@ -836,6 +889,16 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e" +[[package]] +name = "pharos" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9567389417feee6ce15dd6527a8a1ecac205ef62c2932bcf3d9f6fc5b78b414" +dependencies = [ + "futures", + "rustc_version", +] + [[package]] name = "pin-project" version = "1.0.12" @@ -1042,6 +1105,7 @@ version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfc3304ef531b4ff76b1997b6983475ac092ca94f450c88388fa2a8f4dd80bb1" dependencies = [ + "async-tungstenite", "bytes", "flume", "http", @@ -1052,6 +1116,8 @@ dependencies = [ "thiserror", "tokio", "tokio-rustls", + "url", + "ws_stream_tungstenite", ] [[package]] @@ -1066,6 +1132,15 @@ dependencies = [ "serde_json", ] +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + [[package]] name = "rustls" version = "0.20.6" @@ -1163,6 +1238,12 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93f6841e709003d68bb2deee8c343572bf446003ec20a583e76f7b15cebf3711" + [[package]] name = "serde" version = "1.0.144" @@ -1235,6 +1316,19 @@ dependencies = [ "winapi", ] +[[package]] +name = "sha-1" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99cd6713db3cf16b6c84e06321e049a9b9f699826e16096d23bbcc44d15d51a6" +dependencies = [ + "block-buffer 0.9.0", + "cfg-if", + "cpufeatures", + "digest 0.9.0", + "opaque-debug", +] + [[package]] name = "sha-1" version = "0.10.0" @@ -1243,7 +1337,7 @@ checksum = "028f48d513f9678cda28f6e4064755b3fbb2af6acd672f2c209b62323f7aea0f" dependencies = [ "cfg-if", "cpufeatures", - "digest", + "digest 0.10.3", ] [[package]] @@ -1255,6 +1349,15 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "signal-hook-registry" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0" +dependencies = [ + "libc", +] + [[package]] name = "slab" version = "0.4.7" @@ -1317,7 +1420,7 @@ dependencies = [ "tokio-tungstenite", "tracing", "tracing-subscriber", - "tungstenite", + "tungstenite 0.17.3", ] [[package]] @@ -1404,6 +1507,7 @@ dependencies = [ "num_cpus", "once_cell", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "winapi", @@ -1470,7 +1574,7 @@ dependencies = [ "futures-util", "log", "tokio", - "tungstenite", + "tungstenite 0.17.3", ] [[package]] @@ -1567,6 +1671,27 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" +[[package]] +name = "tungstenite" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ad3713a14ae247f22a728a0456a545df14acf3867f905adff84be99e23b3ad1" +dependencies = [ + "base64", + "byteorder", + "bytes", + "http", + "httparse", + "log", + "rand", + "rustls", + "sha-1 0.9.8", + "thiserror", + "url", + "utf-8", + "webpki", +] + [[package]] name = "tungstenite" version = "0.17.3" @@ -1580,7 +1705,7 @@ dependencies = [ "httparse", "log", "rand", - "sha-1", + "sha-1 0.10.0", "thiserror", "url", "utf-8", @@ -1629,6 +1754,7 @@ dependencies = [ "idna", "matches", "percent-encoding", + "serde", ] [[package]] @@ -1833,3 +1959,23 @@ checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" dependencies = [ "winapi", ] + +[[package]] +name = "ws_stream_tungstenite" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a672ec78525bf189cefa7f1b72c55f928b3edbdb967e680ca49748ab20821045" +dependencies = [ + "async-tungstenite", + "async_io_stream", + "bitflags", + "futures-core", + "futures-io", + "futures-sink", + "futures-util", + "log", + "pharos", + "rustc_version", + "tokio", + "tungstenite 0.16.0", +] diff --git a/modbus-mqtt/Cargo.toml b/modbus-mqtt/Cargo.toml index ead08cf..c3a3e81 100644 --- a/modbus-mqtt/Cargo.toml +++ b/modbus-mqtt/Cargo.toml @@ -15,18 +15,20 @@ bytes = "1.1.0" clap = { version = "3.2.12", features = ["derive", "env"] } humantime-serde = "1.1.1" itertools = "0.10.3" -rumqttc = "0.15.0" + +rumqttc = { version = "0.15.0", default-features = true, features = ["url"] } # https://github.com/bytebeamio/rumqtt/issues/446 rust_decimal = { version = "1.26.1", features = ["serde-arbitrary-precision", "serde-float", "serde_json", "maths"] } serde = { version = "1.0.139", features = ["serde_derive"] } serde_json = "1.0.82" serialport = { version = "4.2.0", optional = true, features = ["serde"] } thiserror = "1.0.33" -tokio = { version = "1.20.0", features = ["rt", "rt-multi-thread", "time"] } +tokio = { version = "1.20.0", features = ["rt", "rt-multi-thread", "time", "signal"] } tokio-modbus = { version = "0.5.3", default-features = false } tokio-serial = { version = "5.4.3", optional = true } tokio_modbus-winets = { version = "0.1.0", path = "../tokio_modbus-winets", optional = true, default-features = false } tracing = "0.1.36" tracing-subscriber = "0.3.15" +url = { version = "2.2.2", features = ["serde"] } [dev-dependencies] pretty_assertions = "1.2.1" @@ -36,3 +38,6 @@ default = ["tcp", "rtu", "winet-s"] tcp = ["tokio-modbus/tcp"] rtu = ["tokio-modbus/rtu", "dep:tokio-serial", "dep:serialport"] winet-s = ["dep:tokio_modbus-winets"] +ws = ["rumqttc/websocket"] +tls = ["rustls"] +rustls = ["rumqttc/use-rustls"] diff --git a/modbus-mqtt/src/bin/run.rs b/modbus-mqtt/src/bin/run.rs new file mode 100644 index 0000000..cadd86c --- /dev/null +++ b/modbus-mqtt/src/bin/run.rs @@ -0,0 +1,45 @@ +use clap::Parser; +use modbus_mqtt::{server, Result}; +use url::Url; + +#[derive(Parser, Debug)] +#[clap( + name = "modbus-mqtt", + version, + author, + about = "A bridge between Modbus and MQTT" +)] +struct Cli { + #[clap( + env = "MQTT_URL", + // validator = "is_mqtt_url", + default_value = "mqtt://localhost:1883/modbus-mqtt", + value_hint = clap::ValueHint::Url + )] + url: Url, +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt::init(); + let mut args = Cli::parse(); + + let prefix = args + .url + .path() + .trim_start_matches('/') + .split('/') + .next() + .unwrap_or(env!("CARGO_PKG_NAME")) + .to_owned(); + + // FIXME: if they pass "?client_id=foo" param, skip this + args.url + .query_pairs_mut() + .append_pair("client_id", env!("CARGO_PKG_NAME")) + .finish(); + + server::run(prefix, args.url.try_into()?, tokio::signal::ctrl_c()).await?; + + Ok(()) +} diff --git a/modbus-mqtt/src/connection.rs b/modbus-mqtt/src/connection.rs deleted file mode 100644 index bd10053..0000000 --- a/modbus-mqtt/src/connection.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub struct Connection { - // connect: Connect, - context: tokio_modbus::client::Context, -} diff --git a/modbus-mqtt/src/homeassistant.rs b/modbus-mqtt/src/homeassistant.rs new file mode 100644 index 0000000..588d007 --- /dev/null +++ b/modbus-mqtt/src/homeassistant.rs @@ -0,0 +1,8 @@ +use tokio::sync::mpsc::Sender; + +use crate::{modbus::register::Register, mqtt}; + +/// Describes the register to Home Assistant +fn configure(register: Register, tx: Sender) -> crate::Result<()> { + Ok(()) +} diff --git a/modbus-mqtt/src/main.rs b/modbus-mqtt/src/lib.rs similarity index 88% rename from modbus-mqtt/src/main.rs rename to modbus-mqtt/src/lib.rs index e670cb5..a574179 100644 --- a/modbus-mqtt/src/main.rs +++ b/modbus-mqtt/src/lib.rs @@ -8,7 +8,13 @@ use tracing::{debug, error, info}; use thiserror::Error; -mod modbus; +mod shutdown; +pub(crate) use shutdown::Shutdown; + +pub mod homeassistant; +pub mod modbus; +pub mod mqtt; +pub mod server; #[derive(Error, Debug)] #[non_exhaustive] @@ -16,38 +22,31 @@ pub enum Error { #[error(transparent)] IOError(#[from] std::io::Error), + #[error("{0}")] + Other(std::borrow::Cow<'static, str>), + + #[error(transparent)] + MQTTOptionError(#[from] rumqttc::OptionError), + + #[error(transparent)] + MQTTClientError(#[from] rumqttc::ClientError), + #[error("Unknown")] Unknown, } -type Result = std::result::Result; - -#[derive(clap::Parser)] -#[clap( - name = "modbus-mqtt", - version, - author, - about = "A bridge between Modbus and MQTT" -)] -struct Cli { - mqtt_host: String, - - #[clap(short = 'n', long, default_value = "modbus")] - mqtt_name: String, - - #[clap(short = 'p', long, default_value_t = 1883)] - mqtt_port: u16, - - #[clap(short = 'u', long, env = "MQTT_USER")] - mqtt_user: Option, - - #[clap(short = 'P', long, env)] - mqtt_password: Option, - - #[clap(short = 't', long, default_value = "modbus-mqtt")] - // Where to listen for commands - mqtt_topic_prefix: String, +impl From for Error { + fn from(s: String) -> Self { + Self::Other(s.into()) + } } +impl From<&'static str> for Error { + fn from(s: &'static str) -> Self { + Self::Other(s.into()) + } +} + +pub type Result = std::result::Result; #[derive(Serialize)] #[serde(rename_all = "lowercase")] @@ -56,43 +55,6 @@ enum MainStatus { Stopped, } -#[tokio::main(worker_threads = 3)] -async fn main() -> Result<()> { - tracing_subscriber::fmt::init(); - - use clap::Parser; - let args = Cli::parse(); - - let (registry_tx, registry_rx) = mpsc::channel::(32); - let (dispatcher_tx, dispatcher_rx) = mpsc::channel::(32); - - // Modbus connection registry - let registry_handle = { - let prefix = args.mqtt_topic_prefix.clone(); - tokio::spawn(connection_registry(prefix, dispatcher_tx, registry_rx)) - }; - - // MQTT Dispatcher - let dispatcher_handle = { - let prefix = args.mqtt_topic_prefix.clone(); - let mut options = MqttOptions::new( - env!("CARGO_PKG_NAME"), - args.mqtt_host.as_str(), - args.mqtt_port, - ); - if let (Some(u), Some(p)) = (args.mqtt_user, args.mqtt_password) { - options.set_credentials(u, p); - } - options.set_keep_alive(Duration::from_secs(5)); // TODO: make this configurable - - tokio::spawn(mqtt_dispatcher(options, prefix, registry_tx, dispatcher_rx)) - }; - - registry_handle.await.unwrap(); - dispatcher_handle.await.unwrap(); - Ok(()) -} - #[derive(Debug)] enum DispatchCommand { Publish { topic: String, payload: Vec }, diff --git a/modbus-mqtt/src/modbus/connection.rs b/modbus-mqtt/src/modbus/connection.rs new file mode 100644 index 0000000..195f86f --- /dev/null +++ b/modbus-mqtt/src/modbus/connection.rs @@ -0,0 +1 @@ +pub struct Connection {} diff --git a/modbus-mqtt/src/modbus/mod.rs b/modbus-mqtt/src/modbus/mod.rs index 644cc08..2f8aa89 100644 --- a/modbus-mqtt/src/modbus/mod.rs +++ b/modbus-mqtt/src/modbus/mod.rs @@ -4,6 +4,11 @@ use serde::Serialize; use self::config::{Register, RegisterValueType}; pub mod config; +pub mod connection; +pub mod register; + +pub use connection::Connection; +// pub use register::Register; #[derive(Serialize)] #[serde(rename_all = "lowercase")] diff --git a/modbus-mqtt/src/modbus/register.rs b/modbus-mqtt/src/modbus/register.rs new file mode 100644 index 0000000..59274bd --- /dev/null +++ b/modbus-mqtt/src/modbus/register.rs @@ -0,0 +1 @@ +pub struct Register {} diff --git a/modbus-mqtt/src/mqtt.rs b/modbus-mqtt/src/mqtt.rs new file mode 100644 index 0000000..10beac7 --- /dev/null +++ b/modbus-mqtt/src/mqtt.rs @@ -0,0 +1,225 @@ +use std::collections::HashMap; + +use bytes::Bytes; +use rumqttc::{ + mqttbytes::matches as matches_topic, mqttbytes::valid_topic, AsyncClient, Event, EventLoop, + MqttOptions, Publish, Subscribe, SubscribeFilter, +}; +use tokio::{ + select, + sync::mpsc::{channel, Receiver, Sender}, +}; +use tracing::{debug, warn}; + +use crate::shutdown::Shutdown; + +#[derive(Debug)] +pub enum Message { + Subscribe(Subscribe, Sender), + Publish(Publish), + Shutdown, +} + +pub(crate) async fn new(options: MqttOptions, shutdown: Shutdown) -> Connection { + let (client, event_loop) = AsyncClient::new(options, 32); + + let (tx, rx) = channel(32); + Connection { + client, + event_loop, + subscriptions: HashMap::new(), + tx, + rx, + shutdown, + } +} + +// Maintain internal subscriptions as well as MQTT subscriptions. Relay all received messages on MQTT subscribed topics +// to internal components who have a matching topic. Unsubscribe topics when no one is listening anymore. +pub(crate) struct Connection { + subscriptions: HashMap>>, + tx: Sender, + rx: Receiver, + client: AsyncClient, + event_loop: EventLoop, + shutdown: Shutdown, +} + +impl Connection { + pub async fn run(&mut self) -> crate::Result<()> { + loop { + select! { + event = self.event_loop.poll() => { + match event { + Ok(event) => self.handle_event(event).await?, + _ => todo!() + } + } + request = self.rx.recv() => { + match request { + None => return Ok(()), + Some(Message::Shutdown) => return Ok(()), + Some(req) => self.handle_request(req).await?, + } + } + _ = self.shutdown.recv() => return Ok(()) + } + } + } + + /// Create a handle for interacting with the MQTT server such that a pre-provided prefix is transparently added to + /// all relevant commands which use a topic. + pub fn prefixed_handle + Send>( + &self, + prefix: S, + ) -> crate::Result> { + let prefix = prefix.into(); + + if !valid_topic(&prefix) { + return Err("Prefix is not a valid topic".into()); + } + + let inner_tx = self.handle(); + let (wrapper_tx, mut wrapper_rx) = channel::(8); + + let prefix: String = prefix.into(); + + tokio::spawn(async move { + while let Some(msg) = wrapper_rx.recv().await { + if inner_tx.send(msg.prefixed(prefix.clone())).await.is_err() { + break; + } + } + }); + + Ok(wrapper_tx) + } + + pub fn handle(&self) -> Sender { + self.tx.clone() + } + + async fn handle_event(&mut self, event: Event) -> crate::Result<()> { + use rumqttc::Incoming; + + #[allow(clippy::single_match)] + match event { + Event::Incoming(Incoming::Publish(Publish { topic, payload, .. })) => { + debug!(%topic, ?payload, "publish"); + self.handle_data(topic, payload).await?; + } + // e => debug!(event = ?e), + _ => {} + } + + Ok(()) + } + + #[tracing::instrument(level = "debug", skip(self), fields(subscriptions = ?self.subscriptions.keys()))] + async fn handle_data(&mut self, topic: String, payload: Bytes) -> crate::Result<()> { + let mut targets = vec![]; + + // Remove subscriptions whose channels are closed, adding matching channels to the `targets` vec. + self.subscriptions.retain(|filter, channels| { + if matches_topic(&topic, filter) { + channels.retain(|channel| { + if channel.is_closed() { + warn!(?channel, "closed"); + false + } else { + targets.push(channel.clone()); + true + } + }); + !channels.is_empty() + } else { + true + } + }); + + for target in targets { + if target.send(payload.clone()).await.is_err() { + // These will be removed above next time a matching payload is removed + } + } + Ok(()) + } + + async fn handle_request(&mut self, request: Message) -> crate::Result<()> { + match request { + Message::Publish(Publish { + topic, + payload, + qos, + retain, + .. + }) => { + self.client + .publish_bytes(topic, qos, retain, payload) + .await? + } + Message::Subscribe(Subscribe { filters, .. }, channel) => { + for filter in &filters { + let channel = channel.clone(); + + match self.subscriptions.get_mut(&filter.path) { + Some(channels) => channels.push(channel), + None => { + self.subscriptions + .insert(filter.path.clone(), vec![channel]); + } + } + } + + self.client.subscribe_many(filters).await? + } + Message::Shutdown => panic!("Handled by the caller"), + } + Ok(()) + } +} + +trait Prefixable { + fn prefixed>(self, prefix: S) -> Self; +} + +impl Prefixable for Message { + fn prefixed>(self, prefix: S) -> Self { + match self { + Message::Subscribe(sub, bytes) => Message::Subscribe(sub.prefixed(prefix), bytes), + Message::Publish(publish) => Message::Publish(publish.prefixed(prefix)), + other => other, + } + } +} + +impl Prefixable for Subscribe { + fn prefixed>(mut self, prefix: S) -> Self { + let prefix: String = prefix.into(); + Self { + pkid: self.pkid, + filters: self + .filters + .iter_mut() + .map(|f| f.clone().prefixed(prefix.clone())) + .collect(), + } + } +} + +impl Prefixable for Publish { + fn prefixed>(self, prefix: S) -> Self { + let mut prefixed = self.clone(); + prefixed.topic = format!("{}/{}", prefix.into(), &self.topic); + prefixed + } +} + +impl Prefixable for SubscribeFilter { + fn prefixed>(self, prefix: S) -> Self { + SubscribeFilter { + path: format!("{}/{}", prefix.into(), &self.path), + qos: self.qos, + } + } +} diff --git a/modbus-mqtt/src/server.rs b/modbus-mqtt/src/server.rs new file mode 100644 index 0000000..3c2f8a1 --- /dev/null +++ b/modbus-mqtt/src/server.rs @@ -0,0 +1,90 @@ +use crate::mqtt; +use rumqttc::MqttOptions; +use std::future::Future; +use tokio::sync::broadcast; +use tracing::{debug, error, info}; + +pub struct Server { + notify_shutdown: broadcast::Sender<()>, + mqtt_connection: mqtt::Connection, +} + +pub async fn run>( + prefix: P, + mqtt_options: MqttOptions, + shutdown: impl Future, +) -> crate::Result<()> { + let (notify_shutdown, _) = broadcast::channel(1); + let mqtt_connection = mqtt::new(mqtt_options, notify_shutdown.subscribe().into()).await; + + let mut server = Server { + notify_shutdown, + mqtt_connection, + }; + + let mut ret = Ok(()); + + tokio::select! { + res = server.run() => { + if let Err(err) = res { + error!(cause = %err, "server error"); + ret = Err(err) + } else { + info!("server finished running") + } + } + + _ = shutdown => { + info!("shutting down"); + } + } + + let Server { + notify_shutdown, .. + } = server; + + drop(notify_shutdown); + + ret +} + +impl Server { + async fn run(&mut self) -> crate::Result<()> { + info!("Starting up"); + + let tx = self.mqtt_connection.prefixed_handle("hello")?; + + { + let tx = tx.clone(); + tokio::spawn(async move { + let mut interval = tokio::time::interval(std::time::Duration::from_secs(1)); + loop { + interval.tick().await; + tx.send(mqtt::Message::Publish(rumqttc::Publish::new( + "foo/bar/baz", + rumqttc::QoS::AtLeastOnce, + "hello", + ))) + .await + .unwrap(); + } + }); + } + + tokio::spawn(async move { + let (tx_bytes, mut rx) = tokio::sync::mpsc::channel(32); + tx.send(mqtt::Message::Subscribe( + rumqttc::Subscribe::new("foo/+/baz", rumqttc::QoS::AtLeastOnce), + tx_bytes, + )) + .await + .unwrap(); + + while let Some(bytes) = rx.recv().await { + debug!(?bytes, "received"); + } + }); + + self.mqtt_connection.run().await + } +} diff --git a/modbus-mqtt/src/shutdown.rs b/modbus-mqtt/src/shutdown.rs new file mode 100644 index 0000000..587f78b --- /dev/null +++ b/modbus-mqtt/src/shutdown.rs @@ -0,0 +1,58 @@ +//! **Note**: this is a barely modified copy of the code which appears in mini-redis + +type Notify = tokio::sync::broadcast::Receiver<()>; + +/// Listens for the server shutdown signal. +/// +/// Shutdown is signalled using a `broadcast::Receiver`. Only a single value is +/// ever sent. Once a value has been sent via the broadcast channel, the server +/// should shutdown. +/// +/// The `Shutdown` struct listens for the signal and tracks that the signal has +/// been received. Callers may query for whether the shutdown signal has been +/// received or not. +/// +#[derive(Debug)] +pub(crate) struct Shutdown { + /// `true` if the shutdown signal has been received + shutdown: bool, + + /// The receive half of the channel used to listen for shutdown. + notify: Notify, +} + +impl Shutdown { + /// Create a new `Shutdown` backed by the given `broadcast::Receiver`. + pub(crate) fn new(notify: Notify) -> Shutdown { + Shutdown { + shutdown: false, + notify, + } + } + + /// Returns `true` if the shutdown signal has been received. + pub(crate) fn is_shutdown(&self) -> bool { + self.shutdown + } + + /// Receive the shutdown notice, waiting if necessary. + pub(crate) async fn recv(&mut self) { + // If the shutdown signal has already been received, then return + // immediately. + if self.shutdown { + return; + } + + // Cannot receive a "lag error" as only one value is ever sent. + let _ = self.notify.recv().await; + + // Remember that the signal has been received. + self.shutdown = true; + } +} + +impl From for Shutdown { + fn from(notify: Notify) -> Self { + Self::new(notify) + } +}