From 5dd0e383eaaec36fcb378894b966e098d94fe897 Mon Sep 17 00:00:00 2001 From: Bo Jeanes Date: Fri, 9 Sep 2022 17:01:19 +1000 Subject: [PATCH] Rewrite & refactor (#4) --- Cargo.lock | 178 +++++- README.md | 203 ++++--- modbus-mqtt/Cargo.toml | 29 +- .../examples/sungrow-sh5.0rs-http.json | 19 + modbus-mqtt/examples/sungrow-sh5.0rs.json | 114 ++++ modbus-mqtt/src/bin/run.rs | 80 +++ modbus-mqtt/src/connection.rs | 4 - modbus-mqtt/src/error.rs | 55 ++ modbus-mqtt/src/lib.rs | 10 + modbus-mqtt/src/main.rs | 448 -------------- modbus-mqtt/src/modbus/config.rs | 507 ---------------- modbus-mqtt/src/modbus/connection.rs | 529 +++++++++++++++++ modbus-mqtt/src/modbus/connector.rs | 133 +++++ modbus-mqtt/src/modbus/mod.rs | 151 +---- modbus-mqtt/src/modbus/register.rs | 550 ++++++++++++++++++ modbus-mqtt/src/mqtt.rs | 299 ++++++++++ modbus-mqtt/src/server.rs | 57 ++ modbus-mqtt/src/shutdown.rs | 92 +++ sungrow-winets/src/lib.rs | 5 +- 19 files changed, 2260 insertions(+), 1203 deletions(-) create mode 100644 modbus-mqtt/examples/sungrow-sh5.0rs-http.json create mode 100644 modbus-mqtt/examples/sungrow-sh5.0rs.json create mode 100644 modbus-mqtt/src/bin/run.rs delete mode 100644 modbus-mqtt/src/connection.rs create mode 100644 modbus-mqtt/src/error.rs create mode 100644 modbus-mqtt/src/lib.rs delete mode 100644 modbus-mqtt/src/main.rs delete mode 100644 modbus-mqtt/src/modbus/config.rs create mode 100644 modbus-mqtt/src/modbus/connection.rs create mode 100644 modbus-mqtt/src/modbus/connector.rs create mode 100644 modbus-mqtt/src/modbus/register.rs create mode 100644 modbus-mqtt/src/mqtt.rs create mode 100644 modbus-mqtt/src/server.rs create mode 100644 modbus-mqtt/src/shutdown.rs diff --git a/Cargo.lock b/Cargo.lock index 1712e5c..8741c5e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -67,6 +67,34 @@ dependencies = [ "syn", ] +[[package]] +name = "async-tungstenite" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5682ea0913e5c20780fe5785abacb85a411e7437bf52a1bedb93ddb3972cb8dd" +dependencies = [ + "futures-io", + "futures-util", + "log", + "pin-project-lite", + "rustls-native-certs", + "tokio", + "tokio-rustls", + "tungstenite 0.16.0", +] + +[[package]] +name = "async_io_stream" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d7b9decdf35d8908a7e3ef02f64c5e9b1695e230154c0e8de3969142d9b94c" +dependencies = [ + "futures", + "pharos", + "rustc_version", + "tokio", +] + [[package]] name = "atty" version = "0.2.14" @@ -106,6 +134,15 @@ dependencies = [ "syn", ] +[[package]] +name = "block-buffer" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" +dependencies = [ + "generic-array", +] + [[package]] name = "block-buffer" version = "0.10.2" @@ -247,13 +284,22 @@ version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" +[[package]] +name = "digest" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066" +dependencies = [ + "generic-array", +] + [[package]] name = "digest" version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2fb860ca6fafa5552fb6d0e816a69c8e49f0908bf524e30a90d97c85892d506" dependencies = [ - "block-buffer", + "block-buffer 0.10.2", "crypto-common", ] @@ -720,7 +766,7 @@ dependencies = [ [[package]] name = "modbus-mqtt" -version = "0.1.0" +version = "0.2.0" dependencies = [ "bytes", "clap", @@ -732,13 +778,14 @@ dependencies = [ "serde", "serde_json", "serialport", + "thiserror", "tokio", "tokio-modbus", "tokio-serial", "tokio_modbus-winets", "tracing", "tracing-subscriber", - "uuid", + "url", ] [[package]] @@ -809,6 +856,12 @@ version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "074864da206b4973b84eb91683020dbefd6a8c3f0f38e054d93954e891935e4e" +[[package]] +name = "opaque-debug" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" + [[package]] name = "openssl-probe" version = "0.1.5" @@ -836,6 +889,16 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e" +[[package]] +name = "pharos" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9567389417feee6ce15dd6527a8a1ecac205ef62c2932bcf3d9f6fc5b78b414" +dependencies = [ + "futures", + "rustc_version", +] + [[package]] name = "pin-project" version = "1.0.12" @@ -1042,6 +1105,7 @@ version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfc3304ef531b4ff76b1997b6983475ac092ca94f450c88388fa2a8f4dd80bb1" dependencies = [ + "async-tungstenite", "bytes", "flume", "http", @@ -1052,6 +1116,8 @@ dependencies = [ "thiserror", "tokio", "tokio-rustls", + "url", + "ws_stream_tungstenite", ] [[package]] @@ -1066,6 +1132,15 @@ dependencies = [ "serde_json", ] +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + [[package]] name = "rustls" version = "0.20.6" @@ -1163,6 +1238,12 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93f6841e709003d68bb2deee8c343572bf446003ec20a583e76f7b15cebf3711" + [[package]] name = "serde" version = "1.0.144" @@ -1235,6 +1316,19 @@ dependencies = [ "winapi", ] +[[package]] +name = "sha-1" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99cd6713db3cf16b6c84e06321e049a9b9f699826e16096d23bbcc44d15d51a6" +dependencies = [ + "block-buffer 0.9.0", + "cfg-if", + "cpufeatures", + "digest 0.9.0", + "opaque-debug", +] + [[package]] name = "sha-1" version = "0.10.0" @@ -1243,7 +1337,7 @@ checksum = "028f48d513f9678cda28f6e4064755b3fbb2af6acd672f2c209b62323f7aea0f" dependencies = [ "cfg-if", "cpufeatures", - "digest", + "digest 0.10.3", ] [[package]] @@ -1255,6 +1349,15 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "signal-hook-registry" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0" +dependencies = [ + "libc", +] + [[package]] name = "slab" version = "0.4.7" @@ -1317,7 +1420,7 @@ dependencies = [ "tokio-tungstenite", "tracing", "tracing-subscriber", - "tungstenite", + "tungstenite 0.17.3", ] [[package]] @@ -1348,18 +1451,18 @@ checksum = "b1141d4d61095b28419e22cb0bbf02755f5e54e0526f97f1e3d1d160e60885fb" [[package]] name = "thiserror" -version = "1.0.32" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5f6586b7f764adc0231f4c79be7b920e766bb2f3e51b3661cdb263828f19994" +checksum = "3d0a539a918745651435ac7db7a18761589a94cd7e94cd56999f828bf73c8a57" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.32" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12bafc5b54507e0149cdf1b145a5d80ab80a90bcd9275df43d4fff68460f6c21" +checksum = "c251e90f708e16c49a16f4917dc2131e75222b72edfa9cb7f7c58ae56aae0c09" dependencies = [ "proc-macro2", "quote", @@ -1404,6 +1507,7 @@ dependencies = [ "num_cpus", "once_cell", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "winapi", @@ -1470,7 +1574,7 @@ dependencies = [ "futures-util", "log", "tokio", - "tungstenite", + "tungstenite 0.17.3", ] [[package]] @@ -1567,6 +1671,27 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" +[[package]] +name = "tungstenite" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ad3713a14ae247f22a728a0456a545df14acf3867f905adff84be99e23b3ad1" +dependencies = [ + "base64", + "byteorder", + "bytes", + "http", + "httparse", + "log", + "rand", + "rustls", + "sha-1 0.9.8", + "thiserror", + "url", + "utf-8", + "webpki", +] + [[package]] name = "tungstenite" version = "0.17.3" @@ -1580,7 +1705,7 @@ dependencies = [ "httparse", "log", "rand", - "sha-1", + "sha-1 0.10.0", "thiserror", "url", "utf-8", @@ -1629,6 +1754,7 @@ dependencies = [ "idna", "matches", "percent-encoding", + "serde", ] [[package]] @@ -1637,16 +1763,6 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" -[[package]] -name = "uuid" -version = "1.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd6469f4314d5f1ffec476e05f17cc9a78bc7a27a6a857842170bdf8d6f98d2f" -dependencies = [ - "getrandom", - "serde", -] - [[package]] name = "valuable" version = "0.1.0" @@ -1843,3 +1959,23 @@ checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" dependencies = [ "winapi", ] + +[[package]] +name = "ws_stream_tungstenite" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a672ec78525bf189cefa7f1b72c55f928b3edbdb967e680ca49748ab20821045" +dependencies = [ + "async-tungstenite", + "async_io_stream", + "bitflags", + "futures-core", + "futures-io", + "futures-sink", + "futures-util", + "log", + "pharos", + "rustc_version", + "tokio", + "tungstenite 0.16.0", +] diff --git a/README.md b/README.md index 029425e..a128abf 100644 --- a/README.md +++ b/README.md @@ -4,76 +4,147 @@ A bridge between Modbus devices and MQTT. It is early days, but the plan is: -* Support custom Modbus transports (Sungrow WiNet-S has been implemented) -* Support _setting_ holding registers over MQTT -* Support optional auto-configuration of Home Assistant entities, including using [MQTT Number](https://www.home-assistant.io/integrations/number.mqtt/) et al for holding registers, to allow setting the value. +* [x] Support custom Modbus transports (Sungrow WiNet-S has been implemented) + * Modbus RTU has not been tested because I don't have a serial Modbus device, but in principle it should work. Please let me know +* [x] Support reading input registers +* [x] Support reading holding registers +* [ ] Support _setting_ holding registers +* [ ] Support optional auto-configuration of Home Assistant entities, including using [MQTT Number](https://www.home-assistant.io/integrations/number.mqtt/) et al for holding registers, to allow setting the value. +* [ ] TLS MQTT connections +* [ ] ws:// and ws:// MQTT connections + +NOTE: For the time being, this does not support MQTTv5. + +## Installing + +For now, use `cargo install` (Rust toolchain required). Soon, I will have release binaries attached to GitHub releases. In the future, there will also be Docker images made available for convenience. + +## Running + +Start the binary, passing in the URL to your MQTT server, including any credentials: + +```sh-session +$ modbus-mqtt mqtt://$MQTT_HOST[:$MQTT_PORT]/[$CUSTOM_MODBUS_TOPIC] +``` + +The supported protocols are currently just `tcp://`/`mqtt://`, but with intent to support: `mqtts://`, `ssl://`/`tls://`, `ws://`, and `wss://`. + +The default topic which ModbusMQTT monitors and to which it publishes is `modbus-mqtt`. You can vary that by changing the path portion of the MQTT URL. + +Further, you can change other MQTT options by using query params, such as setting a custom client_id: + +```sh +"mqtt://1.2.3.4/?client_id=$CUSTOM_CLIENT_ID" +``` + +For a full list of supported options, check [the MQTT client library's source code](https://github.com/bytebeamio/rumqtt/blob/c6dc1f7cfb26f6c1f676954a51b398708d49091a/rumqttc/src/lib.rs#L680-L768). + +### Connecting to Modbus devices + +To connect to a Modbus device, you need to post the connection details to MQTT under a topic of `$prefix/$connection_id/connect`. It is intended that such messages are marked as **retained** so that ModbusMQTT reconnects to your devices when it restarts. + +For instance, a simple config might be: + +```jsonc +// PUBLISH modbus-mqtt/solar-inverter/connect +{ + "host": "10.10.10.219", + "proto": "tcp", +} +``` + +If the connection is successful, you will see the following message like the following sent to the MQTT server: + +```jsonc +// modbus-mqtt/solar-inverter/state +"connected" +``` + +#### Full connection examples + +All fields accepted (optional fields show defaults) + +```jsonc +{ + // Common fields + "address_offset": 0, // optional + "unit": 1, // optional, aliased to "slave" + + // TCP: + "proto": "tcp", + "host": "1.2.3.4", + "port": 502, // optional + + // RTU / Serial: + "proto": "rtu", + "tty": "/dev/ttyACM0", + "data_bits": "Eight", // optional (TODO: accept numeric and lowercase) + // valid: Five, Six, Seven, Eight + "stop_bits": "One", // optional (TODO: accept numeric and lowercase) + // valid: One, Two + "flow_control": "None", // optional (TODO: accept lowercase) + // valid: None, Software, Hardware + "parity": "None", // optional (TODO: accept lowercase) + // valid: None, Odd, Even + + // Sungrow WiNet-S dongle + "proto": "winet-s", + "host": "1.2.3.4", +} +``` + +#### Monitoring registers + +Post to `$MODBUS_MQTT_TOPIC/$CONNECTION_ID/$TYPE/$ADDRESS` where `$TYPE` is one of `input` or `holding` with the following payload (optional fields show defaults): + +```jsonc +{ + "name": null, // OPTIONAL - gives the register a name which is used in the register MQTT topics (must be a valid topic component) + + "interval": "1m", // OPTIONAL - how often to update the registers value to MQTT + // e.g.: 3s (every 3 seconds) + // 2m (every 2 minutes) + // 1h (every 1 hour) + + "swap_bytes": false, // OPTIONAL + "swap_words": false, // OPTIONAL + + "type": "s16", // OPTIONAL + // valid: s8, s16, s32, s64 (signed) + // u8, u16, u32, u64 (unsigned) + // f32, f64 (floating point) + + "scale": 0, // OPTIONAL - number in register will be multiplied by 10^(scale) + // e.g.: to turn kW into W, you would provide scale=3 + // to turn W into kW, you would provide scale=-3 + + "offset": 0, // OPTIONAL - will be added to the final result (AFTER scaling) + // Additionally, "type" can be set to "array": + "type": "array", + "of": "u16" // The default array element is u16, but you can change it with the `of` field +} +``` + +Further, the `type` field can additionally be set to `"array"`, in which case, a `count` field must be provided. The array elements default to `"s16"` but can be overriden in the `"of"` field. + +NOTE: this is likely to change such that there is always a `count` field (with default of 1) and if provided to be greater than 1, it will be interpreted to be an array of elements of the `type` specified. + +There is some code to accept `"string"` type (with a required `length` field) but this is experimental and untested. + +##### Register shorthand + +When issuing the `connect` payload, you can optionally include `input` and/or `holding` fields as arrays containing the above register schema, as long as an `address` field is added. When present, these payloads will be replayed to the MQTT server as if the user had specified each register separately, as above. + +This is a recommended way to specify connections, but the registers are broken out separately so that they can be dynamically added to too. + +## Development + +TODO: set up something like https://hub.docker.com/r/oitc/modbus-server to test with ## Similar projects * https://github.com/Instathings/modbus2mqtt * https://github.com/TenySmart/ModbusTCP2MQTT - Sungrow inverter specific - -## Example connect config - -```json -{ - "host": "10.10.10.219", - "unit": 1, - "proto": "tcp", - "address_offset": -1, - "input": [{ - "address": 5017, - "type": "u32", - "name": "dc_power", - "swap_words": false, - "period": "3s" - }, - { - "address": 5008, - "type": "s16", - "name": "internal_temperature", - "period": "1m" - }, - { - "address": 13008, - "type": "s32", - "name": "load_power", - "swap_words": false, - "period": "3s" - }, - { - "address": 13010, - "type": "s32", - "name": "export_power", - "swap_words": false, - "period": "3s" - }, - { - "address": 13022, - "name": "battery_power", - "period": "3s" - }, - { - "address": 13023, - "name": "battery_level", - "period": "1m" - }, - { - "address": 13024, - "name": "battery_health", - "period": "10m" - }], - "hold": [{ - "address": 13058, - "name": "max_soc", - "period": "90s" - }, - { - "address": 13059, - "name": "min_soc", - "period": "90s" - }] -} -``` +* https://github.com/bohdan-s/SunGather - Sungrow inverter specific \ No newline at end of file diff --git a/modbus-mqtt/Cargo.toml b/modbus-mqtt/Cargo.toml index ec71cf1..f69dacf 100644 --- a/modbus-mqtt/Cargo.toml +++ b/modbus-mqtt/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "modbus-mqtt" -version = "0.1.0" +version = "0.2.0" edition = "2021" authors = ["Bo Jeanes "] description = "A bridge between Modbus devices and MQTT" @@ -15,18 +15,29 @@ bytes = "1.1.0" clap = { version = "3.2.12", features = ["derive", "env"] } humantime-serde = "1.1.1" itertools = "0.10.3" -rumqttc = "0.15.0" + +rumqttc = { version = "0.15.0", default-features = true, features = ["url"] } # https://github.com/bytebeamio/rumqtt/issues/446 rust_decimal = { version = "1.26.1", features = ["serde-arbitrary-precision", "serde-float", "serde_json", "maths"] } serde = { version = "1.0.139", features = ["serde_derive"] } -serde_json = "1.0.82" -serialport = { version = "4.2.0", features = ["serde"] } -tokio = { version = "1.20.0", features = ["rt", "rt-multi-thread", "time"] } -tokio-modbus = "0.5.3" -tokio-serial = "5.4.3" -tokio_modbus-winets = { version = "0.1.0", path = "../tokio_modbus-winets" } +serde_json = { version = "1.0.82", features = ["raw_value"] } +serialport = { version = "4.2.0", optional = true, features = ["serde"] } +thiserror = "1.0.33" +tokio = { version = "1.20.0", features = ["rt", "rt-multi-thread", "time", "signal"] } +tokio-modbus = { version = "0.5.3", default-features = false } +tokio-serial = { version = "5.4.3", optional = true } +tokio_modbus-winets = { version = "0.1.0", path = "../tokio_modbus-winets", optional = true, default-features = false } tracing = "0.1.36" tracing-subscriber = "0.3.15" -uuid = { version = "1.1.2", features = ["v4", "serde"] } +url = { version = "2.2.2", features = ["serde"] } [dev-dependencies] pretty_assertions = "1.2.1" + +[features] +default = ["tcp", "rtu", "winet-s"] +tcp = ["tokio-modbus/tcp"] +rtu = ["tokio-modbus/rtu", "dep:tokio-serial", "dep:serialport"] +winet-s = ["dep:tokio_modbus-winets"] +ws = ["rumqttc/websocket"] +tls = ["rustls"] +rustls = ["rumqttc/use-rustls"] diff --git a/modbus-mqtt/examples/sungrow-sh5.0rs-http.json b/modbus-mqtt/examples/sungrow-sh5.0rs-http.json new file mode 100644 index 0000000..2461328 --- /dev/null +++ b/modbus-mqtt/examples/sungrow-sh5.0rs-http.json @@ -0,0 +1,19 @@ +{ + "host": "10.10.10.219", + "unit": 1, + "proto": "winet-s", + "input": [ + { + "address": 13000, + "type": "u16", + "name": "running_state", + "period": "3s" + }, + { + "address": 13022, + "type": "s16", + "name": "battery_power", + "period": "3s" + } + ] +} \ No newline at end of file diff --git a/modbus-mqtt/examples/sungrow-sh5.0rs.json b/modbus-mqtt/examples/sungrow-sh5.0rs.json new file mode 100644 index 0000000..4de50b7 --- /dev/null +++ b/modbus-mqtt/examples/sungrow-sh5.0rs.json @@ -0,0 +1,114 @@ +{ + "host": "10.10.10.219", + "unit": 1, + "proto": "tcp", + "address_offset": -1, + "input": [ + { + "address": 5017, + "type": "u32", + "name": "dc_power", + "swap_words": true, + "period": "1s" + }, + { + "address": 13034, + "type": "u32", + "name": "active_power", + "swap_words": true, + "period": "1s" + }, + { + "address": 5008, + "type": "s16", + "name": "internal_temperature", + "period": "1m", + "scale": -1 + }, + { + "address": 13008, + "type": "s32", + "name": "load_power", + "swap_words": true, + "period": "1s" + }, + { + "address": 13010, + "type": "s32", + "name": "export_power", + "swap_words": true, + "period": "1s" + }, + { + "address": 13022, + "name": "battery_power", + "period": "1s" + }, + { + "address": 13023, + "name": "battery_level", + "period": "1m", + "scale": -1 + }, + { + "address": 13024, + "name": "battery_health", + "period": "10m", + "scale": -1 + }, + { + "address": 5036, + "name": "grid_frequency", + "period": "1m" + }, + { + "address": 5019, + "name": "phase_a_voltage", + "period": "1m" + }, + { + "address": 13031, + "name": "phase_a_current", + "period": "1m" + }, + { + "address": 5011, + "name": "mppt1_voltage" + }, + { + "address": 5012, + "name": "mppt1_current" + }, + { + "address": 5012, + "name": "mppt2_voltage" + }, + { + "address": 5013, + "name": "mppt2_current" + } + ], + "holding": [ + { + "address": 13058, + "name": "max_soc", + "period": "90s", + "scale": -1 + }, + { + "address": 13059, + "name": "min_soc", + "period": "90s", + "scale": -1 + }, + { + "address": 13100, + "name": "battery_reserve" + }, + { + "address": 33148, + "name": "forced_battery_power", + "scale": 1 + } + ] +} \ No newline at end of file diff --git a/modbus-mqtt/src/bin/run.rs b/modbus-mqtt/src/bin/run.rs new file mode 100644 index 0000000..8f7fd74 --- /dev/null +++ b/modbus-mqtt/src/bin/run.rs @@ -0,0 +1,80 @@ +use clap::Parser; +use modbus_mqtt::{server, Result}; +use rumqttc::MqttOptions; +use tokio::select; +use url::Url; + +#[derive(Parser, Debug)] +#[clap( + name = "modbus-mqtt", + version, + author, + about = "A bridge between Modbus and MQTT" +)] +struct Cli { + #[clap( + env = "MQTT_URL", + // validator = "is_mqtt_url", + default_value = "mqtt://localhost:1883/modbus-mqtt", + value_hint = clap::ValueHint::Url, + help = "Pass the topic prefix as the URL path" + )] + url: Url, +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt::init(); + + let Cli { mut url } = Cli::parse(); + + let mut prefix = url + .path() + .trim_start_matches('/') + .trim_end_matches('/') + .to_owned(); + + let options: MqttOptions = match url.clone().try_into() { + Ok(options) => options, + Err(rumqttc::OptionError::ClientId) => { + let url = url + .query_pairs_mut() + .append_pair("client_id", env!("CARGO_PKG_NAME")) + .finish() + .clone(); + url.try_into()? + } + Err(other) => return Err(other.into()), + }; + + if prefix.is_empty() { + prefix = options.client_id(); + } + + let shutdown = async move { + let ctrl_c = tokio::signal::ctrl_c(); + + #[cfg(unix)] + { + use tokio::signal::unix::{signal, SignalKind}; + + let mut term = signal(SignalKind::terminate()).unwrap(); + let mut int = signal(SignalKind::interrupt()).unwrap(); + let mut hup = signal(SignalKind::hangup()).unwrap(); + + select! { + _ = ctrl_c => {}, + _ = term.recv() => {}, + _ = int.recv() => {}, + _ = hup.recv() => {}, + } + } + + #[cfg(not(unix))] + ctrl_c.await; + }; + + server::run(prefix, options, shutdown).await?; + + Ok(()) +} diff --git a/modbus-mqtt/src/connection.rs b/modbus-mqtt/src/connection.rs deleted file mode 100644 index bd10053..0000000 --- a/modbus-mqtt/src/connection.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub struct Connection { - // connect: Connect, - context: tokio_modbus::client::Context, -} diff --git a/modbus-mqtt/src/error.rs b/modbus-mqtt/src/error.rs new file mode 100644 index 0000000..372435b --- /dev/null +++ b/modbus-mqtt/src/error.rs @@ -0,0 +1,55 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +#[non_exhaustive] +pub enum Error { + #[error(transparent)] + IOError(#[from] std::io::Error), + + #[error(transparent)] + MQTTOptionError(#[from] rumqttc::OptionError), + + #[error(transparent)] + MQTTClientError(#[from] rumqttc::ClientError), + + #[error(transparent)] + MQTTConnectionError(#[from] rumqttc::ConnectionError), + + #[error(transparent)] + InvalidSocketAddr(#[from] std::net::AddrParseError), + + #[error(transparent)] + SerialError(#[from] tokio_serial::Error), + + #[error(transparent)] + ParseIntError(#[from] std::num::ParseIntError), + + #[error(transparent)] + JSONError(#[from] serde_json::Error), + + #[error("RecvError")] + RecvError, + + #[error("SendError")] + SendError, + + #[error("Unrecognised modbus protocol")] + UnrecognisedModbusProtocol, + + #[error("{0}")] + Other(std::borrow::Cow<'static, str>), + + #[error("Unknown")] + Unknown, +} + +impl From for Error { + fn from(s: String) -> Self { + Self::Other(s.into()) + } +} +impl From<&'static str> for Error { + fn from(s: &'static str) -> Self { + Self::Other(s.into()) + } +} diff --git a/modbus-mqtt/src/lib.rs b/modbus-mqtt/src/lib.rs new file mode 100644 index 0000000..b15521b --- /dev/null +++ b/modbus-mqtt/src/lib.rs @@ -0,0 +1,10 @@ +mod shutdown; + +pub mod modbus; +pub mod mqtt; +pub mod server; + +mod error; +pub use error::Error; + +pub type Result = std::result::Result; diff --git a/modbus-mqtt/src/main.rs b/modbus-mqtt/src/main.rs deleted file mode 100644 index c4ad5ad..0000000 --- a/modbus-mqtt/src/main.rs +++ /dev/null @@ -1,448 +0,0 @@ -use rumqttc::{self, AsyncClient, Event, Incoming, LastWill, MqttOptions, Publish, QoS}; -use serde::Serialize; -use serde_json::json; -use std::{collections::HashMap, time::Duration}; -use tokio::{sync::mpsc, sync::oneshot, time::MissedTickBehavior}; -use tokio_modbus::prelude::*; -use tracing::{debug, error, info}; - -use clap::Parser; - -mod modbus; - -#[derive(Parser)] -struct Cli { - mqtt_host: String, - - #[clap(short = 'n', long, default_value = "modbus")] - mqtt_name: String, - - #[clap(short = 'p', long, default_value_t = 1883)] - mqtt_port: u16, - - #[clap(short = 'u', long, env = "MQTT_USER")] - mqtt_user: Option, - - #[clap(short = 'P', long, env)] - mqtt_password: Option, - - #[clap(short = 't', long, default_value = "modbus-mqtt")] - // Where to listen for commands - mqtt_topic_prefix: String, -} - -#[derive(Serialize)] -#[serde(rename_all = "lowercase")] -enum MainStatus { - Running, - Stopped, -} - -#[tokio::main(worker_threads = 3)] -async fn main() { - tracing_subscriber::fmt::init(); - - let args = Cli::parse(); - - let (registry_tx, registry_rx) = mpsc::channel::(32); - let (dispatcher_tx, dispatcher_rx) = mpsc::channel::(32); - - // Modbus connection registry - let registry_handle = { - let prefix = args.mqtt_topic_prefix.clone(); - tokio::spawn(connection_registry(prefix, dispatcher_tx, registry_rx)) - }; - - // MQTT Dispatcher - let dispatcher_handle = { - let prefix = args.mqtt_topic_prefix.clone(); - let mut options = MqttOptions::new( - env!("CARGO_PKG_NAME"), - args.mqtt_host.as_str(), - args.mqtt_port, - ); - if let (Some(u), Some(p)) = (args.mqtt_user, args.mqtt_password) { - options.set_credentials(u, p); - } - options.set_keep_alive(Duration::from_secs(5)); // TODO: make this configurable - - tokio::spawn(mqtt_dispatcher(options, prefix, registry_tx, dispatcher_rx)) - }; - - registry_handle.await.unwrap(); - dispatcher_handle.await.unwrap(); -} - -#[derive(Debug)] -enum DispatchCommand { - Publish { topic: String, payload: Vec }, -} -#[tracing::instrument(level = "debug")] -async fn mqtt_dispatcher( - mut options: MqttOptions, - prefix: String, - registry: mpsc::Sender, - mut rx: mpsc::Receiver, -) { - info!("Connecting to MQTT broker..."); - - options.set_last_will(LastWill { - topic: format!("{}/status", prefix), - message: serde_json::to_vec(&json!({ - "status": MainStatus::Stopped, - })) - .unwrap() - .into(), - qos: QoS::AtMostOnce, - retain: false, - }); - - let (client, mut eventloop) = AsyncClient::new(options, 10); - - client - .publish( - format!("{}/status", prefix), - QoS::AtMostOnce, - false, - serde_json::to_vec(&json!({ - "status": MainStatus::Running, - })) - .unwrap(), - ) - .await - .unwrap(); - - client - .subscribe(format!("{}/connect/#", prefix), QoS::AtMostOnce) - .await - .unwrap(); - - let rx_loop_handler = { - let client = client.clone(); - tokio::spawn(async move { - info!("Start dispatcher rx loop"); - while let Some(command) = rx.recv().await { - match command { - DispatchCommand::Publish { topic, payload } => { - client - .publish(topic, QoS::AtMostOnce, false, payload) - .await - .unwrap(); - } - } - } - }) - }; - - while let Ok(event) = eventloop.poll().await { - use Event::{Incoming as In, Outgoing as Out}; - - match event { - Out(_) => (), - In(Incoming::ConnAck(_)) => info!("Connected to MQTT!"), - In(Incoming::PingResp | Incoming::SubAck(_)) => (), - - In(Incoming::Publish(Publish { topic, payload, .. })) => { - debug!("{} -> {:?}", &topic, &payload); - - match topic.split('/').collect::>()[..] { - [p, "connect", conn_name] if p == prefix.as_str() => { - registry - .send(RegistryCommand::Connect { - id: conn_name.to_string(), - details: payload, - }) - .await - .unwrap(); - } - _ => (), - }; - } - _ => { - debug!("{:?}", event); - } - } - } - - rx_loop_handler.await.unwrap(); -} - -type ConnectionId = String; - -#[derive(Debug)] -enum RegistryCommand { - Connect { - id: ConnectionId, - details: bytes::Bytes, - }, - Disconnect(ConnectionId), -} - -type RegistryDb = HashMap>; - -#[tracing::instrument(level = "debug")] -async fn connection_registry( - prefix: String, - dispatcher: mpsc::Sender, - mut rx: mpsc::Receiver, -) { - info!("Starting connection registry..."); - let mut db: RegistryDb = HashMap::new(); - - while let Some(command) = rx.recv().await { - use RegistryCommand::*; - match command { - Disconnect(id) => { - if let Some(handle) = db.remove(&id) { - handle.abort(); - } - } - Connect { id, details } => { - info!(id, payload = ?details, "Establishing connection"); - let prefix = prefix.clone(); - let dispatcher = dispatcher.clone(); - - if let Some(handle) = db.remove(&id) { - handle.abort(); - } - - db.insert( - id.clone(), - tokio::spawn(handle_connect(dispatcher, id, prefix, details)), - ); - } - _ => error!("unimplemented"), - } - } -} - -#[derive(Clone, Copy, Debug)] -enum ModbusReadType { - Input, - Hold, -} - -#[derive(Debug)] -enum ModbusCommand { - Read(ModbusReadType, u16, u8, ModbusResponse), - Write(u16, Vec, ModbusResponse), -} - -type ModbusResponse = oneshot::Sender, std::io::Error>>; - -#[tracing::instrument(level = "debug")] -async fn handle_connect( - dispatcher: mpsc::Sender, - id: ConnectionId, - topic_prefix: String, - payload: bytes::Bytes, -) { - use modbus::config::*; - use modbus::ConnectState; - info!("Starting connection handler for {}", id); - match serde_json::from_slice::(&payload) { - Ok(connect) => { - let unit = connect.unit; - - let mut modbus = match connect.settings { - ModbusProto::SungrowWiNetS { ref host } => { - tokio_modbus_winets::connect_slave(host, unit) - .await - .unwrap() - } - ModbusProto::Tcp { ref host, port } => { - let socket_addr = format!("{}:{}", host, port).parse().unwrap(); - tcp::connect_slave(socket_addr, unit).await.unwrap() - } - ModbusProto::Rtu { - ref tty, - baud_rate, - data_bits, - stop_bits, - flow_control, - parity, - } => { - let builder = tokio_serial::new(tty, baud_rate) - .data_bits(data_bits) - .flow_control(flow_control) - .parity(parity) - .stop_bits(stop_bits); - let port = tokio_serial::SerialStream::open(&builder).unwrap(); - rtu::connect_slave(port, unit).await.unwrap() - } - }; - let status = modbus::ConnectStatus { - connect: connect.clone(), - status: ConnectState::Connected, - }; - dispatcher - .send(DispatchCommand::Publish { - topic: format!("{}/status/{}", topic_prefix, id), - payload: serde_json::to_vec(&status).unwrap(), - }) - .await - .unwrap(); - - let (modbus_tx, mut modbus_rx) = mpsc::channel::(32); - tokio::spawn(async move { - while let Some(command) = modbus_rx.recv().await { - match command { - ModbusCommand::Read(read_type, address, count, responder) => { - let response = match read_type { - ModbusReadType::Input => { - modbus.read_input_registers(address, count as u16) - } - ModbusReadType::Hold => { - modbus.read_holding_registers(address, count as u16) - } - }; - - responder.send(response.await).unwrap(); - } - ModbusCommand::Write(address, data, responder) => { - responder - .send( - modbus - .read_write_multiple_registers( - address, - data.len() as u16, - address, - &data[..], - ) - .await, - ) - .unwrap(); - } - } - } - }); - - use itertools::Itertools; - for (duration, registers) in &connect.input.into_iter().group_by(|r| r.interval) { - let registers_prefix = format!("{}/input/{}", topic_prefix, id); - - tokio::spawn(watch_registers( - ModbusReadType::Input, - connect.address_offset, - duration, - registers.collect(), - modbus_tx.clone(), - dispatcher.clone(), - registers_prefix, - )); - } - for (duration, registers) in &connect.hold.into_iter().group_by(|r| r.interval) { - let registers_prefix = format!("{}/hold/{}", topic_prefix, id); - - tokio::spawn(watch_registers( - ModbusReadType::Hold, - connect.address_offset, - duration, - registers.collect(), - modbus_tx.clone(), - dispatcher.clone(), - registers_prefix, - )); - } - } - Err(err) => { - dispatcher - .send(DispatchCommand::Publish { - topic: format!("{}/status/{}", topic_prefix, id), - payload: serde_json::to_vec(&json!({ - "status": ConnectState::Errored, - "error": format!("Invalid config: {}", err), - })) - .unwrap(), - }) - .await - .unwrap(); - } - } -} - -#[tracing::instrument(level = "debug")] -async fn watch_registers( - read_type: ModbusReadType, - address_offset: i8, - duration: Duration, - registers: Vec, - modbus: mpsc::Sender, - dispatcher: mpsc::Sender, - registers_prefix: String, -) -> ! { - let mut interval = tokio::time::interval(duration); - interval.set_missed_tick_behavior(MissedTickBehavior::Delay); - - loop { - interval.tick().await; - for r in registers.iter() { - let address = if address_offset >= 0 { - r.address.checked_add(address_offset as u16) - } else { - r.address.checked_sub(address_offset.unsigned_abs() as u16) - }; - if let Some(address) = address { - let size = r.parse.value_type.size(); - debug!( - name = r.name.as_ref().unwrap_or(&"".to_string()), - address, - size, - register_type = ?read_type, - value_type = r.parse.value_type.type_name(), - "Polling register", - ); - - let (tx, rx) = oneshot::channel(); - - modbus - .send(ModbusCommand::Read(read_type, address, size, tx)) - .await - .unwrap(); - - // FIXME: definitely getting errors here that need to be handled - // - // thread 'tokio-runtime-worker' panicked at 'called `Result::unwrap()` on an `Err` value: Error { kind: UnexpectedEof, message: "failed to fill whole buffer" }' - // thread 'tokio-runtime-worker' panicked at 'called `Result::unwrap()` on an `Err` value: Custom { kind: InvalidData, error: "Invalid data length: 0" }' - // thread 'tokio-runtime-worker' panicked at 'called `Result::unwrap()` on an `Err` value: Os { code: 36, kind: Uncategorized, message: "Operation now in progress" }' - // thread 'tokio-runtime-worker' panicked at 'called `Result::unwrap()` on an `Err` value: Os { code: 35, kind: WouldBlock, message: "Resource temporarily unavailable" } - // - // Splitting out the two awaits so I can see if all of the above panics come from the same await or some from one vs the other: - let response = rx.await.unwrap(); // await may have errorer on receiving - let words = response.unwrap(); // received message is also a result which may be a (presumably Modbus?) error - - let swapped_words = r.apply_swaps(&words); - - let value = r.parse_words(&swapped_words); - - debug!( - name = r.name.as_ref().unwrap_or(&"".to_string()), - address, - %value, - raw = ?words, - "Received value", - ); - - let payload = serde_json::to_vec(&json!({ "value": value, "raw": words })).unwrap(); - - dispatcher - .send(DispatchCommand::Publish { - topic: format!("{}/{}", registers_prefix, r.address), - payload: payload.clone(), - }) - .await - .unwrap(); - - if let Some(name) = &r.name { - dispatcher - .send(DispatchCommand::Publish { - topic: format!("{}/{}", registers_prefix, name), - payload, - }) - .await - .unwrap(); - } - } - } - } -} diff --git a/modbus-mqtt/src/modbus/config.rs b/modbus-mqtt/src/modbus/config.rs deleted file mode 100644 index 25c8c51..0000000 --- a/modbus-mqtt/src/modbus/config.rs +++ /dev/null @@ -1,507 +0,0 @@ -use serde::{Deserialize, Serialize}; -use std::time::Duration; - -#[cfg(test)] -use serde_json::json; - -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(tag = "proto", rename_all = "lowercase")] -pub enum ModbusProto { - Tcp { - host: String, - - #[serde(default = "default_modbus_port")] - port: u16, - }, - #[serde(rename_all = "lowercase")] - Rtu { - // tty: std::path::PathBuf, - tty: String, - baud_rate: u32, - - #[serde(default = "default_modbus_data_bits")] - data_bits: tokio_serial::DataBits, // TODO: allow this to be represented as a number instead of string - - #[serde(default = "default_modbus_stop_bits")] - stop_bits: tokio_serial::StopBits, // TODO: allow this to be represented as a number instead of string - - #[serde(default = "default_modbus_flow_control")] - flow_control: tokio_serial::FlowControl, - - #[serde(default = "default_modbus_parity")] - parity: tokio_serial::Parity, - }, - #[serde(rename = "winet-s")] - SungrowWiNetS { host: String }, -} - -fn default_modbus_port() -> u16 { - 502 -} - -fn default_modbus_data_bits() -> tokio_serial::DataBits { - tokio_serial::DataBits::Eight -} - -fn default_modbus_stop_bits() -> tokio_serial::StopBits { - tokio_serial::StopBits::One -} - -fn default_modbus_flow_control() -> tokio_serial::FlowControl { - tokio_serial::FlowControl::None -} - -fn default_modbus_parity() -> tokio_serial::Parity { - tokio_serial::Parity::None -} - -#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "lowercase", default)] -pub struct RegisterNumericAdjustment { - pub scale: i8, // powers of 10 (0 = no adjustment, 1 = x10, -1 = /10) - pub offset: i8, - // precision: Option, -} - -#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] -pub enum RegisterNumeric { - U8, - #[default] - U16, - U32, - U64, - - #[serde(alias = "s8")] - I8, - #[serde(alias = "s16")] - I16, - #[serde(alias = "s32")] - I32, - #[serde(alias = "s64")] - I64, - - F32, - F64, -} - -impl RegisterNumeric { - // Modbus limits sequential reads to 125 apparently, so 8-bit should be fine - https://github.com/slowtec/tokio-modbus/issues/112#issuecomment-1095316069= - fn size(&self) -> u8 { - use RegisterNumeric::*; - // Each Modbus register holds 16-bits, so count is half what the byte count would be - match self { - U8 | I8 => 1, - U16 | I16 => 1, - U32 | I32 | F32 => 2, - U64 | I64 | F64 => 4, - } - } - - fn type_name(&self) -> String { - format!("{:?}", *self).to_lowercase() - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -#[serde(tag = "type", rename = "string")] -pub struct RegisterString { - length: u8, -} - -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -#[serde(tag = "type", rename = "array")] -pub struct RegisterArray { - count: u8, - - #[serde(default)] - of: RegisterNumeric, - - // Arrays are only of numeric types, so we can apply an adjustment here - #[serde(flatten, skip_serializing_if = "IsDefault::is_default")] - adjust: RegisterNumericAdjustment, -} - -impl Default for RegisterArray { - fn default() -> Self { - Self { - count: 1, - of: Default::default(), - adjust: Default::default(), - } - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] -#[serde(untagged)] -pub enum RegisterValueType { - Numeric { - #[serde(rename = "type", default)] - of: RegisterNumeric, - - #[serde(flatten, skip_serializing_if = "IsDefault::is_default")] - adjust: RegisterNumericAdjustment, - }, - Array(RegisterArray), - String(RegisterString), -} - -impl RegisterValueType { - pub fn type_name(&self) -> String { - match *self { - RegisterValueType::Numeric { ref of, .. } => of.type_name(), - RegisterValueType::Array(_) => "array".to_owned(), - RegisterValueType::String(_) => "string".to_owned(), - } - } -} - -impl Default for RegisterValueType { - fn default() -> Self { - RegisterValueType::Numeric { - of: Default::default(), - adjust: Default::default(), - } - } -} - -impl RegisterValueType { - // Modbus limits sequential reads to 125 apparently, so 8-bit should be fine - https://github.com/slowtec/tokio-modbus/issues/112#issuecomment-1095316069= - pub fn size(&self) -> u8 { - use RegisterValueType::*; - - match self { - Numeric { of, .. } => of.size(), - String(RegisterString { length }) => *length, - Array(RegisterArray { of, count, .. }) => of.size() * count, - } - } -} - -#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize, Deserialize)] -#[serde(transparent)] -pub struct Swap(pub bool); - -trait IsDefault { - fn is_default(&self) -> bool; -} -impl IsDefault for T -where - T: Default + PartialEq, -{ - fn is_default(&self) -> bool { - *self == Default::default() - } -} - -#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] -pub struct RegisterParse { - #[serde(default, skip_serializing_if = "IsDefault::is_default")] - pub swap_bytes: Swap, - - #[serde(default, skip_serializing_if = "IsDefault::is_default")] - pub swap_words: Swap, - - #[serde(flatten, skip_serializing_if = "IsDefault::is_default")] - pub value_type: RegisterValueType, -} - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Register { - pub address: u16, - - #[serde(skip_serializing_if = "Option::is_none")] - pub name: Option, - - #[serde(flatten, default, skip_serializing_if = "IsDefault::is_default")] - pub parse: RegisterParse, - - #[serde( - with = "humantime_serde", - default = "default_register_interval", - alias = "period", - alias = "duration" - )] - pub interval: Duration, -} - -fn default_register_interval() -> Duration { - Duration::from_secs(60) -} - -#[derive(Clone, Serialize, Deserialize)] -pub struct Connect { - #[serde(flatten)] - pub settings: ModbusProto, - - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub input: Vec, - - #[serde(default, skip_serializing_if = "Vec::is_empty")] - pub hold: Vec, - - #[serde( - alias = "slave", - default = "tokio_modbus::slave::Slave::broadcast", - with = "Unit" - )] - pub unit: crate::modbus::Unit, - - #[serde(default = "default_address_offset")] - pub address_offset: i8, -} - -#[derive(Serialize, Deserialize)] -#[serde(remote = "tokio_modbus::slave::Slave")] -struct Unit(crate::modbus::UnitId); - -fn default_address_offset() -> i8 { - 0 -} - -#[test] -fn parse_minimal_tcp_connect_config() { - let result = serde_json::from_value::(json!({ - "proto": "tcp", - "host": "1.1.1.1" - })); - - let connect = result.unwrap(); - assert!(matches!( - connect.settings, - ModbusProto::Tcp { - ref host, - port: 502 - } if host == "1.1.1.1" - )) -} - -#[test] -fn parse_full_tcp_connect_config() { - let _ = serde_json::from_value::(json!({ - "proto": "tcp", - "host": "10.10.10.219", - "unit": 1, - "address_offset": -1, - "input": [ - { - "address": 5017, - "type": "u32", - "name": "dc_power", - "swap_words": false, - "period": "3s" - }, - { - "address": 5008, - "type": "s16", - "name": "internal_temperature", - "period": "1m" - }, - { - "address": 13008, - "type": "s32", - "name": "load_power", - "swap_words": false, - "period": "3s" - }, - { - "address": 13010, - "type": "s32", - "name": "export_power", - "swap_words": false, - "period": "3s" - }, - { - "address": 13022, - "name": "battery_power", - "period": "3s" - }, - { - "address": 13023, - "name": "battery_level", - "period": "1m" - }, - { - "address": 13024, - "name": "battery_health", - "period": "10m" - } - ], - "hold": [ - { - "address": 13058, - "name": "max_soc", - "period": "90s" - }, - { - "address": 13059, - "name": "min_soc", - "period": "90s" - } - ] - })) - .unwrap(); -} - -#[test] -fn parse_minimal_rtu_connect_config() { - let result = serde_json::from_value::(json!({ - "proto": "rtu", - "tty": "/dev/ttyUSB0", - "baud_rate": 9600, - })); - - let connect = result.unwrap(); - use tokio_serial::*; - assert!(matches!( - connect.settings, - ModbusProto::Rtu { - ref tty, - baud_rate: 9600, - data_bits: DataBits::Eight, - stop_bits: StopBits::One, - flow_control: FlowControl::None, - parity: Parity::None, - .. - } if tty == "/dev/ttyUSB0" - )) -} - -#[test] -fn parse_complete_rtu_connect_config() { - let result = serde_json::from_value::(json!({ - "proto": "rtu", - "tty": "/dev/ttyUSB0", - "baud_rate": 12800, - - // TODO: make lowercase words work - "data_bits": "Seven", // TODO: make 7 work - "stop_bits": "Two", // TODO: make 2 work - "flow_control": "Software", - "parity": "Even", - })); - - let connect = result.unwrap(); - use tokio_serial::*; - assert!(matches!( - connect.settings, - ModbusProto::Rtu { - ref tty, - baud_rate: 12800, - data_bits: DataBits::Seven, - stop_bits: StopBits::Two, - flow_control: FlowControl::Software, - parity: Parity::Even, - .. - } if tty == "/dev/ttyUSB0" - ),); -} - -#[test] -fn parse_empty_register_parser_defaults() { - let empty = serde_json::from_value::(json!({})); - assert!(matches!( - empty.unwrap(), - RegisterParse { - swap_bytes: Swap(false), - swap_words: Swap(false), - value_type: RegisterValueType::Numeric { - of: RegisterNumeric::U16, - adjust: RegisterNumericAdjustment { - scale: 0, - offset: 0, - } - } - } - )); -} - -#[test] -fn parse_register_parser_type() { - let result = serde_json::from_value::(json!({ - "type": "s32" - })); - assert!(matches!( - result.unwrap().value_type, - RegisterValueType::Numeric { - of: RegisterNumeric::I32, - .. - } - )); -} - -#[test] -fn parse_register_parser_array() { - let result = serde_json::from_value::(json!({ - "type": "array", - "of": "s32", - "count": 10, - })); - let payload = result.unwrap(); - // println!("{:?}", payload); - // println!("{}", serde_json::to_string_pretty(&payload).unwrap()); - - assert!(matches!( - payload.value_type, - RegisterValueType::Array(RegisterArray { - of: RegisterNumeric::I32, - count: 10, - .. - }) - )); -} - -#[test] -fn parse_register_parser_array_implicit_u16() { - let result = serde_json::from_value::(json!({ - "type": "array", - "count": 10, - })); - let payload = result.unwrap(); - // println!("{:?}", payload); - // println!("{}", serde_json::to_string_pretty(&payload).unwrap()); - - assert!(matches!( - payload.value_type, - RegisterValueType::Array(RegisterArray { - of: RegisterNumeric::U16, - count: 10, - .. - }) - )); -} - -#[test] -fn parse_register_parser_string() { - let result = serde_json::from_value::(json!({ - "type": "string", - "length": 10, - })); - let payload = result.unwrap(); - // println!("{:?}", payload); - // println!("{}", serde_json::to_string_pretty(&payload).unwrap()); - - assert!(matches!( - payload.value_type, - RegisterValueType::String(RegisterString { length: 10, .. }) - )); -} - -#[test] -fn parse_register_parser_scale_etc() { - let result = serde_json::from_value::(json!({ - "type": "s32", - "scale": -1, - "offset": 20, - })); - assert!(matches!( - result.unwrap().value_type, - RegisterValueType::Numeric { - of: RegisterNumeric::I32, - adjust: RegisterNumericAdjustment { - scale: -1, - offset: 20 - } - } - )); -} diff --git a/modbus-mqtt/src/modbus/connection.rs b/modbus-mqtt/src/modbus/connection.rs new file mode 100644 index 0000000..1521221 --- /dev/null +++ b/modbus-mqtt/src/modbus/connection.rs @@ -0,0 +1,529 @@ +use super::Word; +use crate::modbus::{self, register}; +use crate::mqtt::Scopable; +use crate::Error; +use rust_decimal::prelude::Zero; +use serde::Deserialize; +use tokio::sync::oneshot; +use tokio::{select, sync::mpsc}; +use tokio_modbus::client::{rtu, tcp, Context as ModbusClient}; +use tracing::{debug, error, warn}; + +use crate::{mqtt, shutdown::Shutdown}; + +use super::register::RegisterType; + +pub(crate) async fn run( + config: Config, + mqtt: mqtt::Handle, + shutdown: Shutdown, +) -> crate::Result { + let (handle_tx, handle_rx) = tokio::sync::oneshot::channel(); + + tokio::spawn(async move { + // Can unwrap because if MQTT handler is bad, we have nothing to do here. + mqtt.publish("state", "connecting").await.unwrap(); + + let address_offset = config.address_offset; + + match config.settings.connect(config.unit).await { + Ok(client) => { + // Can unwrap because if MQTT handler is bad, we have nothing to do here. + mqtt.publish("state", "connected").await.unwrap(); + + let (tx, rx) = mpsc::channel(32); + + let conn = Connection { + address_offset, + client, + mqtt: mqtt.clone(), + shutdown: shutdown.clone(), // Important, so that we can publish "disconnected" below + rx, + tx, + }; + + handle_tx.send(Ok(conn.handle())).unwrap(); + + if let Err(error) = conn.run().await { + error!(?error, "Modbus connection failed"); + } + + // we are shutting down here, so don't care if this fails + let send = mqtt.publish("state", "disconnected").await; + debug!(?config, ?send, "shutting down modbus connection"); + } + Err(error) => handle_tx.send(Err(error)).unwrap(), + } + }); + + handle_rx.await.map_err(|_| crate::Error::RecvError)? +} + +struct Connection { + client: ModbusClient, + address_offset: i8, + mqtt: mqtt::Handle, + shutdown: Shutdown, + rx: mpsc::Receiver, + tx: mpsc::Sender, +} + +#[derive(Debug)] +pub struct Handle { + tx: mpsc::Sender, +} + +impl Handle { + pub async fn write_register(&self, address: u16, data: Vec) -> crate::Result> { + let (tx, rx) = oneshot::channel(); + self.tx + .send(Command::Write(address, data, tx)) + .await + .map_err(|_| Error::SendError)?; + rx.await.map_err(|_| Error::RecvError)? + } + pub async fn read_input_register( + &self, + address: u16, + quantity: u8, + ) -> crate::Result> { + self.read_register(RegisterType::Input, address, quantity) + .await + } + pub async fn read_holding_register( + &self, + address: u16, + quantity: u8, + ) -> crate::Result> { + self.read_register(RegisterType::Holding, address, quantity) + .await + } + + async fn read_register( + &self, + reg_type: RegisterType, + address: u16, + quantity: u8, + ) -> crate::Result> { + let (tx, rx) = oneshot::channel(); + self.tx + .send(Command::Read(reg_type, address, quantity, tx)) + .await + .map_err(|_| Error::SendError)?; + rx.await.map_err(|_| Error::RecvError)? + } +} + +type Response = oneshot::Sender>>; + +#[derive(Debug)] +enum Command { + Read(RegisterType, u16, u8, Response), + Write(u16, Vec, Response), +} + +impl Connection { + pub async fn run(mut self) -> crate::Result<()> { + let mut registers_rx = register::subscribe(&self.mqtt).await?; + + loop { + select! { + Some(cmd) = self.rx.recv() => { self.process_command(cmd).await; }, + + Some((reg_type, reg)) = registers_rx.recv() => { + debug!(?reg_type, ?reg); + let scope = format!( + "{}/{}", + match ®_type { + RegisterType::Input => "input", + RegisterType::Holding => "holding", + }, + reg.address + ); + let mqtt = self.mqtt.scoped(scope); + let modbus = self.handle(); + register::Monitor::new( + reg.register, + reg_type, + reg.address, + mqtt, + modbus, + ) + .run() + .await; + }, + + _ = self.shutdown.recv() => { + return Ok(()); + } + } + } + } + + fn handle(&self) -> Handle { + Handle { + tx: self.tx.clone(), + } + } + + // TODO: if we get a new register definition for an existing register, how do we avoid redundant (and possibly + // conflicting) tasks? Should MQTT component only allow one subscriber per topic filter, replacing the old one + // when it gets a new subscribe request? + // IDEA: Allow providing a subscription ID which _replaces_ any existing subscription with the same ID + + /// Apply address offset to address. + /// + /// Panics if offset would overflow or underflow the address. + fn adjust_address(&self, address: u16) -> u16 { + if self.address_offset.is_zero() { + return address; + } + + // TODO: use `checked_add_signed()` once stabilised: https://doc.rust-lang.org/std/primitive.u16.html#method.checked_add_signed + let adjusted_address = if self.address_offset >= 0 { + address.checked_add(self.address_offset as u16) + } else { + address.checked_sub(self.address_offset.unsigned_abs() as u16) + }; + + if let Some(address) = adjusted_address { + address + } else { + error!(address, offset = self.address_offset,); + address + // panic!("Address offset would underflow/overflow") + } + } + + async fn process_command(&mut self, cmd: Command) { + use tokio_modbus::prelude::Reader; + + let (tx, response) = match cmd { + Command::Read(RegisterType::Input, address, count, tx) => { + let address = self.adjust_address(address); + ( + tx, + self.client + .read_input_registers(address, count as u16) + .await, + ) + } + Command::Read(RegisterType::Holding, address, count, tx) => { + let address = self.adjust_address(address); + ( + tx, + self.client + .read_holding_registers(address, count as u16) + .await, + ) + } + Command::Write(address, data, tx) => { + let address = self.adjust_address(address); + ( + tx, + self.client + .read_write_multiple_registers( + address, + data.len() as u16, + address, + &data[..], + ) + .await, + ) + } + }; + + // This might be transient, so don't kill connection. We may be able to discriminate on the error to determine + // which errors are transient and which are conclusive. + // + // Some errors that we have observed: + // + // Error { kind: UnexpectedEof, message: "failed to fill whole buffer" }' + // Custom { kind: InvalidData, error: "Invalid data length: 0" }' + // Os { code: 36, kind: Uncategorized, message: "Operation now in progress" }' + // Os { code: 35, kind: WouldBlock, message: "Resource temporarily unavailable" } + // + if let Err(error) = &response { + match error.kind() { + std::io::ErrorKind::UnexpectedEof => { + // THIS happening feels like a bug either in how I am using tokio_modbus or in tokio_modbus. It seems + // like the underlying buffers get all messed up and restarting doesn't always fix it unless I wait a + // few seconds. I might need to get help from someone to figure it out. + error!(?error, "Connection error, may not be recoverable"); + } + _ => error!(?error), + } + } + + // This probably just means that the register task died or is no longer monitoring the response. + if let Err(response) = tx.send(response.map_err(Into::into)) { + warn!(?response, "error sending response"); + } + } +} + +#[derive(Debug, Deserialize)] +pub(crate) struct Config { + #[serde(flatten)] + pub settings: ModbusProto, + + #[serde( + alias = "slave", + default = "tokio_modbus::slave::Slave::broadcast", + with = "Unit" + )] + pub unit: modbus::Unit, + + #[serde(default)] + pub address_offset: i8, +} + +#[derive(Deserialize)] +#[serde(remote = "tokio_modbus::slave::Slave")] +pub(crate) struct Unit(crate::modbus::UnitId); + +#[derive(Clone, Debug, Deserialize)] +#[serde(tag = "proto", rename_all = "lowercase")] +pub(crate) enum ModbusProto { + #[cfg(feature = "tcp")] + Tcp { + host: String, + + #[serde(default = "default_modbus_port")] + port: u16, + }, + #[cfg(feature = "rtu")] + #[serde(rename_all = "lowercase")] + Rtu { + // tty: std::path::PathBuf, + tty: String, + baud_rate: u32, + + #[serde(default = "default_modbus_data_bits")] + data_bits: tokio_serial::DataBits, // TODO: allow this to be represented as a number instead of string + + #[serde(default = "default_modbus_stop_bits")] + stop_bits: tokio_serial::StopBits, // TODO: allow this to be represented as a number instead of string + + #[serde(default = "default_modbus_flow_control")] + flow_control: tokio_serial::FlowControl, + + #[serde(default = "default_modbus_parity")] + parity: tokio_serial::Parity, + }, + #[cfg(feature = "winet-s")] + #[serde(rename = "winet-s")] + SungrowWiNetS { host: String }, + + // Predominantly for if the binary is compiled with no default features for some reason. + #[serde(other)] + Unknown, +} + +impl ModbusProto { + // Can we use the "slave context" thing in Modbus to pass the unit later? + pub async fn connect(&self, unit: modbus::Unit) -> crate::Result { + let client = match *self { + #[cfg(feature = "winet-s")] + ModbusProto::SungrowWiNetS { ref host } => { + tokio_modbus_winets::connect_slave(host, unit).await? + } + + #[cfg(feature = "tcp")] + ModbusProto::Tcp { ref host, port } => { + let socket_addr = format!("{}:{}", host, port).parse()?; + tcp::connect_slave(socket_addr, unit).await? + } + + #[cfg(feature = "rtu")] + ModbusProto::Rtu { + ref tty, + baud_rate, + data_bits, + stop_bits, + flow_control, + parity, + } => { + let builder = tokio_serial::new(tty, baud_rate) + .data_bits(data_bits) + .flow_control(flow_control) + .parity(parity) + .stop_bits(stop_bits); + let port = tokio_serial::SerialStream::open(&builder)?; + rtu::connect_slave(port, unit).await? + } + + ModbusProto::Unknown => { + error!("Unrecognised protocol"); + Err(Error::UnrecognisedModbusProtocol)? + } + }; + Ok(client) + } +} + +pub(crate) fn default_modbus_port() -> u16 { + 502 +} + +#[cfg(feature = "rtu")] +pub(crate) fn default_modbus_data_bits() -> tokio_serial::DataBits { + tokio_serial::DataBits::Eight +} + +#[cfg(feature = "rtu")] +pub(crate) fn default_modbus_stop_bits() -> tokio_serial::StopBits { + tokio_serial::StopBits::One +} + +#[cfg(feature = "rtu")] +pub(crate) fn default_modbus_flow_control() -> tokio_serial::FlowControl { + tokio_serial::FlowControl::None +} + +#[cfg(feature = "rtu")] +pub(crate) fn default_modbus_parity() -> tokio_serial::Parity { + tokio_serial::Parity::None +} + +#[test] +fn parse_minimal_tcp_connect_config() { + use serde_json::json; + let result = serde_json::from_value::(json!({ + "proto": "tcp", + "host": "1.1.1.1" + })); + + let connect = result.unwrap(); + assert!(matches!( + connect.settings, + ModbusProto::Tcp { + ref host, + port: 502 + } if host == "1.1.1.1" + )) +} + +#[test] +fn parse_full_tcp_connect_config() { + use serde_json::json; + let _ = serde_json::from_value::(json!({ + "proto": "tcp", + "host": "10.10.10.219", + "unit": 1, + "address_offset": -1, + "input": [ + { + "address": 5017, + "type": "u32", + "name": "dc_power", + "swap_words": false, + "period": "3s" + }, + { + "address": 5008, + "type": "s16", + "name": "internal_temperature", + "period": "1m" + }, + { + "address": 13008, + "type": "s32", + "name": "load_power", + "swap_words": false, + "period": "3s" + }, + { + "address": 13010, + "type": "s32", + "name": "export_power", + "swap_words": false, + "period": "3s" + }, + { + "address": 13022, + "name": "battery_power", + "period": "3s" + }, + { + "address": 13023, + "name": "battery_level", + "period": "1m" + }, + { + "address": 13024, + "name": "battery_health", + "period": "10m" + } + ], + "hold": [ + { + "address": 13058, + "name": "max_soc", + "period": "90s" + }, + { + "address": 13059, + "name": "min_soc", + "period": "90s" + } + ] + })) + .unwrap(); +} + +#[test] +fn parse_minimal_rtu_connect_config() { + use serde_json::json; + let result = serde_json::from_value::(json!({ + "proto": "rtu", + "tty": "/dev/ttyUSB0", + "baud_rate": 9600, + })); + + let connect = result.unwrap(); + use tokio_serial::*; + assert!(matches!( + connect.settings, + ModbusProto::Rtu { + ref tty, + baud_rate: 9600, + data_bits: DataBits::Eight, + stop_bits: StopBits::One, + flow_control: FlowControl::None, + parity: Parity::None, + .. + } if tty == "/dev/ttyUSB0" + )) +} + +#[test] +fn parse_complete_rtu_connect_config() { + use serde_json::json; + let result = serde_json::from_value::(json!({ + "proto": "rtu", + "tty": "/dev/ttyUSB0", + "baud_rate": 12800, + + // TODO: make lowercase words work + "data_bits": "Seven", // TODO: make 7 work + "stop_bits": "Two", // TODO: make 2 work + "flow_control": "Software", + "parity": "Even", + })); + + let connect = result.unwrap(); + use tokio_serial::*; + assert!(matches!( + connect.settings, + ModbusProto::Rtu { + ref tty, + baud_rate: 12800, + data_bits: DataBits::Seven, + stop_bits: StopBits::Two, + flow_control: FlowControl::Software, + parity: Parity::Even, + .. + } if tty == "/dev/ttyUSB0" + ),); +} diff --git a/modbus-mqtt/src/modbus/connector.rs b/modbus-mqtt/src/modbus/connector.rs new file mode 100644 index 0000000..712c71f --- /dev/null +++ b/modbus-mqtt/src/modbus/connector.rs @@ -0,0 +1,133 @@ +use crate::modbus::{connection, register}; +use crate::mqtt::{Payload, Scopable}; +use crate::{mqtt, shutdown::Shutdown}; +use serde::Deserialize; +use serde_json::value::RawValue as RawJSON; +use tokio::select; +use tracing::{debug, error, info}; + +/* +NOTE: Should this be a connection _registry_ of sorts which also restarts connections which die? +*/ + +/// The topic filter under the prefix to look for connection configs +const TOPIC: &str = "+/connect"; + +/// Responsible for monitoring MQTT topic for connection configs +pub struct Connector { + mqtt: mqtt::Handle, + shutdown: Shutdown, + // connections: Vec, +} + +pub(crate) fn new(mqtt: mqtt::Handle, shutdown: Shutdown) -> Connector { + Connector { + mqtt, + shutdown, + // connections: vec![], + } +} + +impl Connector { + pub async fn run(&mut self) -> crate::Result<()> { + let mut new_connection = self.mqtt.subscribe(TOPIC).await?; + + loop { + select! { + Some(Payload { bytes, topic }) = new_connection.recv() => { + // `unwrap()` is safe here because of the shape of valid topics and the fact that we are subcribed + // to a topic under a prefix. + let connection_id = topic.rsplit('/').nth_back(1).unwrap(); + let mqtt = self.mqtt.scoped(connection_id); + + debug!(?connection_id, ?bytes, ?topic, "Received connection config"); + + if let Err(error) = parse_and_connect(bytes, mqtt, self.shutdown.clone()).await { + error!(?connection_id, ?error, "Error creating connection"); + } + + }, + + _ = self.shutdown.recv() => { + info!("shutting down connector"); + break; + }, + } + } + + Ok(()) + } +} + +async fn parse_and_connect( + bytes: bytes::Bytes, + mqtt: mqtt::Handle, + shutdown: Shutdown, +) -> crate::Result<()> { + match serde_json::from_slice(&bytes) { + Err(_) => mqtt.publish("state", "invalid").await?, + Ok(Config { + connection: + connection::Config { + settings: connection::ModbusProto::Unknown, + .. + }, + .. + }) => mqtt.publish("state", "unknown_proto").await?, + Ok(config) => { + debug!(?config); + connect(config, mqtt, shutdown).await?; + } + } + Ok(()) +} +async fn connect(config: Config<'_>, mqtt: mqtt::Handle, shutdown: Shutdown) -> crate::Result<()> { + if shutdown.is_shutdown() { + return Ok(()); + } + + let Config { + connection: settings, + input, + holding, + } = config; + + let _ = connection::run(settings, mqtt.clone(), shutdown).await?; + + // TODO: consider waiting 1 second before sending the registers to MQTT, to ensure that the connection is listening. + + for (reg_type, registers) in [("holding", holding), ("input", input)] { + let mqtt = mqtt.scoped(reg_type); + for reg in registers { + if let Ok(r) = + serde_json::from_slice::(reg.get().as_bytes()) + { + let json = serde_json::to_vec(&r.register).unwrap(); // unwrap() should be fine because we JUST deserialized it successfully + mqtt.publish(r.address.to_string(), json).await?; + // if let Some(name) = r.register.name { + // r.register.name = None; + // let json = serde_json::to_vec(&r).unwrap(); // unwrap() should be fine because we JUST deserialized it successfully + // mqtt.publish(name, json).await?; + // } + } + } + } + + Ok(()) +} + +/// Wrapper around `modbus::connection::Config` that can include some registers inline, which the connector will +/// re-publish to the appropriate topic once the connection is established. +#[derive(Debug, Deserialize)] +struct Config<'a> { + #[serde(flatten)] + connection: connection::Config, + + // Allow registers to be defined inline, but capture them as raw JSON so that if they have incorrect schema, we can + // still establish the Modbus connection. Valid registers will be re-emitted as individual register configs to MQTT, + // to be picked up by the connection. + #[serde(default, borrow)] + pub input: Vec<&'a RawJSON>, + #[serde(alias = "hold", default, borrow)] + pub holding: Vec<&'a RawJSON>, +} diff --git a/modbus-mqtt/src/modbus/mod.rs b/modbus-mqtt/src/modbus/mod.rs index 644cc08..7d7e8c7 100644 --- a/modbus-mqtt/src/modbus/mod.rs +++ b/modbus-mqtt/src/modbus/mod.rs @@ -1,151 +1,10 @@ -use rust_decimal::{prelude::FromPrimitive, Decimal}; -use serde::Serialize; +pub mod connection; +pub mod connector; +pub mod register; -use self::config::{Register, RegisterValueType}; +pub use connection::Handle; -pub mod config; - -#[derive(Serialize)] -#[serde(rename_all = "lowercase")] -pub enum ConnectState { - Connected, - Disconnected, - Errored, -} - -#[derive(Serialize)] -pub struct ConnectStatus { - #[serde(flatten)] - pub connect: config::Connect, - pub status: ConnectState, -} +type Word = u16; pub type UnitId = tokio_modbus::prelude::SlaveId; pub type Unit = tokio_modbus::prelude::Slave; - -impl RegisterValueType { - pub fn parse_words(&self, words: &[u16]) -> serde_json::Value { - use self::config::RegisterValueType as T; - use self::config::{RegisterArray, RegisterNumeric as N, RegisterString}; - use serde_json::json; - - let bytes: Vec = words.iter().flat_map(|v| v.to_ne_bytes()).collect(); - - match *self { - T::Numeric { ref of, ref adjust } => { - use rust_decimal::MathematicalOps; - let scale: Decimal = Decimal::TEN.powi(adjust.scale.into()).normalize(); - let offset = Decimal::from(adjust.offset); - match of { - N::U8 => json!(scale * Decimal::from(bytes[1]) + offset), // or is it 0? - N::U16 => json!(scale * Decimal::from(words[0]) + offset), - N::U32 => { - json!(bytes - .try_into() - .map(|bytes| scale * Decimal::from(u32::from_le_bytes(bytes)) + offset) - .ok()) - } - N::U64 => { - json!(bytes - .try_into() - .map(|bytes| scale * Decimal::from(u64::from_le_bytes(bytes)) + offset) - .ok()) - } - N::I8 => { - json!(vec![bytes[1]] - .try_into() - .map(|bytes| scale * Decimal::from(i8::from_le_bytes(bytes)) + offset) - .ok()) - } - N::I16 => { - json!(bytes - .try_into() - .map(|bytes| scale * Decimal::from(i16::from_le_bytes(bytes)) + offset) - .ok()) - } - N::I32 => { - json!(bytes - .try_into() - .map(|bytes| scale * Decimal::from(i32::from_le_bytes(bytes)) + offset) - .ok()) - } - N::I64 => { - json!(bytes - .try_into() - .map(|bytes| scale * Decimal::from(i64::from_le_bytes(bytes)) + offset) - .ok()) - } - N::F32 => { - json!(bytes - .try_into() - .map(|bytes| scale - * Decimal::from_f32(f32::from_le_bytes(bytes)).unwrap() - + offset) - .ok()) - } - N::F64 => { - json!(bytes - .try_into() - .map(|bytes| scale - * Decimal::from_f64(f64::from_le_bytes(bytes)).unwrap() - + offset) - .ok()) - } - } - } - T::String(RegisterString { .. }) => { - json!(String::from_utf16_lossy(words)) - } - T::Array(RegisterArray { .. }) => todo!(), - } - } -} - -impl Register { - pub fn parse_words(&self, words: &[u16]) -> serde_json::Value { - self.parse.value_type.parse_words(words) - } - - pub fn apply_swaps(&self, words: &[u16]) -> Vec { - let words: Vec = if self.parse.swap_bytes.0 { - words.iter().map(|v| v.swap_bytes()).collect() - } else { - words.into() - }; - - if self.parse.swap_words.0 { - words - .chunks_exact(2) - .flat_map(|chunk| vec![chunk[1], chunk[0]]) - .collect() - } else { - words - } - } -} -#[cfg(test)] -use pretty_assertions::assert_eq; -#[test] -fn test_parse_1() { - use self::config::{RegisterParse, Swap}; - use serde_json::json; - - let reg = Register { - address: 42, - name: None, - interval: Default::default(), - parse: RegisterParse { - swap_bytes: Swap(false), - swap_words: Swap(false), - value_type: RegisterValueType::Numeric { - of: config::RegisterNumeric::I32, - adjust: config::RegisterNumericAdjustment { - scale: 0, - offset: 0, - }, - }, - }, - }; - - assert_eq!(reg.parse_words(&[843, 0]), json!(843)); -} diff --git a/modbus-mqtt/src/modbus/register.rs b/modbus-mqtt/src/modbus/register.rs new file mode 100644 index 0000000..1841165 --- /dev/null +++ b/modbus-mqtt/src/modbus/register.rs @@ -0,0 +1,550 @@ +use super::Word; +use serde::{Deserialize, Serialize}; +use std::time::Duration; +use tokio::{ + select, + sync::mpsc, + time::{interval, MissedTickBehavior}, +}; +use tracing::{debug, warn}; + +pub struct Monitor { + mqtt: mqtt::Handle, + modbus: super::Handle, + address: u16, + register: Register, + register_type: RegisterType, +} + +impl Monitor { + pub fn new( + register: Register, + register_type: RegisterType, + address: u16, + mqtt: mqtt::Handle, + modbus: super::Handle, + ) -> Monitor { + Monitor { + mqtt, + register_type, + register, + address, + modbus, + } + } + + pub async fn run(self) { + tokio::spawn(async move { + let mut interval = interval(self.register.interval); + interval.set_missed_tick_behavior(MissedTickBehavior::Delay); + + loop { + interval.tick().await; + if let Ok(words) = self.read().await { + debug!(address=%self.address, "type"=?self.register_type, ?words); + + #[cfg(debug_assertions)] + self.mqtt + .publish("raw", serde_json::to_vec(&words).unwrap()) + .await + .unwrap(); + + let value = self.register.parse_words(&words); + + self.mqtt + .publish("state", serde_json::to_vec(&value).unwrap()) + .await + .unwrap(); + } + } + }); + } + + async fn read(&self) -> crate::Result> { + match self.register_type { + RegisterType::Input => { + self.modbus + .read_input_register(self.address, self.register.size()) + .await + } + RegisterType::Holding => { + self.modbus + .read_holding_register(self.address, self.register.size()) + .await + } + } + } +} + +pub(crate) async fn subscribe( + mqtt: &mqtt::Handle, +) -> crate::Result> { + let (tx, rx) = mpsc::channel(8); + let mut input_registers = mqtt.subscribe("input/+").await?; + let mut holding_registers = mqtt.subscribe("holding/+").await?; + + tokio::spawn(async move { + fn to_register(payload: &Payload) -> crate::Result { + let Payload { bytes, topic } = payload; + let address = topic + .rsplit('/') + .next() + .expect("subscribed topic guarantees we have a last segment") + .parse()?; + Ok(AddressedRegister { + address, + register: serde_json::from_slice(bytes)?, + }) + } + + loop { + select! { + Some(ref payload) = input_registers.recv() => { + match to_register(payload) { + Ok(register) => if (tx.send((RegisterType::Input, register)).await).is_err() { break; }, + Err(error) => warn!(?error, def=?payload.bytes, "ignoring invalid input register definition"), + } + }, + Some(ref payload) = holding_registers.recv() => { + match to_register(payload) { + Ok(register) => if (tx.send((RegisterType::Holding, register)).await).is_err() { break; }, + Err(error) => warn!(?error, def=?payload.bytes, "ignoring invalid holding register definition"), + } + } + } + } + }); + + Ok(rx) +} + +#[derive(Clone, Copy, Debug)] +pub enum RegisterType { + Input, + Holding, +} + +#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase", default)] +pub struct RegisterNumericAdjustment { + pub scale: i8, // powers of 10 (0 = no adjustment, 1 = x10, -1 = /10) + pub offset: i8, + // precision: Option, +} + +#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum RegisterNumeric { + U8, + #[default] + U16, + U32, + U64, + + #[serde(alias = "s8")] + I8, + #[serde(alias = "s16")] + I16, + #[serde(alias = "s32")] + I32, + #[serde(alias = "s64")] + I64, + + F32, + F64, +} + +impl RegisterNumeric { + // Modbus limits sequential reads to 125 apparently, so 8-bit should be fine - https://github.com/slowtec/tokio-modbus/issues/112#issuecomment-1095316069= + fn size(&self) -> u8 { + use RegisterNumeric::*; + // Each Modbus register holds 16-bits, so count is half what the byte count would be + match self { + U8 | I8 => 1, + U16 | I16 => 1, + U32 | I32 | F32 => 2, + U64 | I64 | F64 => 4, + } + } + + fn type_name(&self) -> String { + format!("{:?}", *self).to_lowercase() + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "type", rename = "string")] +pub struct RegisterString { + length: u8, +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "type", rename = "array")] +pub struct RegisterArray { + count: u8, + + #[serde(default)] + of: RegisterNumeric, + + // Arrays are only of numeric types, so we can apply an adjustment here + #[serde(flatten, skip_serializing_if = "IsDefault::is_default")] + adjust: RegisterNumericAdjustment, +} + +impl Default for RegisterArray { + fn default() -> Self { + Self { + count: 1, + of: Default::default(), + adjust: Default::default(), + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum RegisterValueType { + Numeric { + #[serde(rename = "type", default)] + of: RegisterNumeric, + + #[serde(flatten, skip_serializing_if = "IsDefault::is_default")] + adjust: RegisterNumericAdjustment, + }, + Array(RegisterArray), + String(RegisterString), +} + +impl RegisterValueType { + pub fn type_name(&self) -> String { + match *self { + RegisterValueType::Numeric { ref of, .. } => of.type_name(), + RegisterValueType::Array(_) => "array".to_owned(), + RegisterValueType::String(_) => "string".to_owned(), + } + } +} + +impl Default for RegisterValueType { + fn default() -> Self { + RegisterValueType::Numeric { + of: Default::default(), + adjust: Default::default(), + } + } +} + +impl RegisterValueType { + // Modbus limits sequential reads to 125 apparently, so 8-bit should be fine - https://github.com/slowtec/tokio-modbus/issues/112#issuecomment-1095316069= + pub fn size(&self) -> u8 { + use RegisterValueType::*; + + match self { + Numeric { of, .. } => of.size(), + String(RegisterString { length }) => *length, + Array(RegisterArray { of, count, .. }) => of.size() * count, + } + } +} + +#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(transparent)] +pub struct Swap(pub bool); + +trait IsDefault { + fn is_default(&self) -> bool; +} +impl IsDefault for T +where + T: Default + PartialEq, +{ + fn is_default(&self) -> bool { + *self == Default::default() + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] +pub struct RegisterParse { + #[serde(default, skip_serializing_if = "IsDefault::is_default")] + pub swap_bytes: Swap, + + #[serde(default, skip_serializing_if = "IsDefault::is_default")] + pub swap_words: Swap, + + #[serde(flatten, skip_serializing_if = "IsDefault::is_default")] + pub value_type: RegisterValueType, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Register { + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + + #[serde(flatten, default, skip_serializing_if = "IsDefault::is_default")] + pub parse: RegisterParse, + + #[serde( + with = "humantime_serde", + default = "default_register_interval", + alias = "period", + alias = "duration" + )] + pub interval: Duration, +} +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AddressedRegister { + pub address: u16, + + #[serde(flatten)] + pub register: Register, +} + +fn default_register_interval() -> Duration { + Duration::from_secs(60) +} + +#[test] +fn parse_empty_register_parser_defaults() { + use serde_json::json; + let empty = serde_json::from_value::(json!({})); + assert!(matches!( + empty.unwrap(), + RegisterParse { + swap_bytes: Swap(false), + swap_words: Swap(false), + value_type: RegisterValueType::Numeric { + of: RegisterNumeric::U16, + adjust: RegisterNumericAdjustment { + scale: 0, + offset: 0, + } + } + } + )); +} + +#[test] +fn parse_register_parser_type() { + use serde_json::json; + let result = serde_json::from_value::(json!({ + "type": "s32" + })); + assert!(matches!( + result.unwrap().value_type, + RegisterValueType::Numeric { + of: RegisterNumeric::I32, + .. + } + )); +} + +#[test] +fn parse_register_parser_array() { + use serde_json::json; + let result = serde_json::from_value::(json!({ + "type": "array", + "of": "s32", + "count": 10, + })); + let payload = result.unwrap(); + // println!("{:?}", payload); + // println!("{}", serde_json::to_string_pretty(&payload).unwrap()); + + assert!(matches!( + payload.value_type, + RegisterValueType::Array(RegisterArray { + of: RegisterNumeric::I32, + count: 10, + .. + }) + )); +} + +#[test] +fn parse_register_parser_array_implicit_u16() { + use serde_json::json; + let result = serde_json::from_value::(json!({ + "type": "array", + "count": 10, + })); + let payload = result.unwrap(); + // println!("{:?}", payload); + // println!("{}", serde_json::to_string_pretty(&payload).unwrap()); + + assert!(matches!( + payload.value_type, + RegisterValueType::Array(RegisterArray { + of: RegisterNumeric::U16, + count: 10, + .. + }) + )); +} + +#[test] +fn parse_register_parser_string() { + use serde_json::json; + let result = serde_json::from_value::(json!({ + "type": "string", + "length": 10, + })); + let payload = result.unwrap(); + // println!("{:?}", payload); + // println!("{}", serde_json::to_string_pretty(&payload).unwrap()); + + assert!(matches!( + payload.value_type, + RegisterValueType::String(RegisterString { length: 10, .. }) + )); +} + +#[test] +fn parse_register_parser_scale_etc() { + use serde_json::json; + let result = serde_json::from_value::(json!({ + "type": "s32", + "scale": -1, + "offset": 20, + })); + assert!(matches!( + result.unwrap().value_type, + RegisterValueType::Numeric { + of: RegisterNumeric::I32, + adjust: RegisterNumericAdjustment { + scale: -1, + offset: 20 + } + } + )); +} + +impl RegisterValueType { + pub fn parse_words(&self, words: &[u16]) -> serde_json::Value { + use self::RegisterNumeric as N; + use rust_decimal::{prelude::FromPrimitive, Decimal, MathematicalOps}; + use serde_json::json; + use RegisterValueType as T; + + let bytes: Vec = words.iter().flat_map(|v| v.to_be_bytes()).collect(); + + match *self { + T::Numeric { ref of, ref adjust } => { + let scale: Decimal = Decimal::TEN.powi(adjust.scale.into()).normalize(); + let offset = Decimal::from(adjust.offset); + match of { + N::U8 => json!(scale * Decimal::from(bytes[1]) + offset), // or is it 0? + N::U16 => json!(scale * Decimal::from(words[0]) + offset), + N::U32 => { + json!(bytes + .try_into() + .map(|bytes| scale * Decimal::from(u32::from_be_bytes(bytes)) + offset) + .ok()) + } + N::U64 => { + json!(bytes + .try_into() + .map(|bytes| scale * Decimal::from(u64::from_be_bytes(bytes)) + offset) + .ok()) + } + N::I8 => { + json!(vec![bytes[1]] + .try_into() + .map(|bytes| scale * Decimal::from(i8::from_be_bytes(bytes)) + offset) + .ok()) + } + N::I16 => { + json!(bytes + .try_into() + .map(|bytes| scale * Decimal::from(i16::from_be_bytes(bytes)) + offset) + .ok()) + } + N::I32 => { + json!(bytes + .try_into() + .map(|bytes| scale * Decimal::from(i32::from_be_bytes(bytes)) + offset) + .ok()) + } + N::I64 => { + json!(bytes + .try_into() + .map(|bytes| scale * Decimal::from(i64::from_be_bytes(bytes)) + offset) + .ok()) + } + N::F32 => { + json!(bytes + .try_into() + .map(|bytes| scale + * Decimal::from_f32(f32::from_be_bytes(bytes)).unwrap() + + offset) + .ok()) + } + N::F64 => { + json!(bytes + .try_into() + .map(|bytes| scale + * Decimal::from_f64(f64::from_be_bytes(bytes)).unwrap() + + offset) + .ok()) + } + } + } + T::String(RegisterString { .. }) => { + json!(String::from_utf16_lossy(words)) + } + T::Array(RegisterArray { .. }) => todo!(), + } + } +} + +impl Register { + pub fn size(&self) -> u8 { + self.parse.value_type.size() + } + + pub fn parse_words(&self, words: &[u16]) -> serde_json::Value { + self.parse.value_type.parse_words(&self.apply_swaps(words)) + } + + fn apply_swaps(&self, words: &[u16]) -> Vec { + let words: Vec = if self.parse.swap_bytes.0 { + words.iter().map(|v| v.swap_bytes()).collect() + } else { + words.into() + }; + + if self.parse.swap_words.0 { + words + .chunks_exact(2) + .flat_map(|chunk| vec![chunk[1], chunk[0]]) + .collect() + } else { + words + } + } +} +#[cfg(test)] +use pretty_assertions::assert_eq; + +use crate::mqtt::{self, Payload}; +#[test] +fn test_parse_1() { + use serde_json::json; + + let reg = Register { + name: None, + interval: Default::default(), + parse: RegisterParse { + swap_bytes: Swap(false), + swap_words: Swap(true), + value_type: RegisterValueType::Numeric { + of: RegisterNumeric::U32, + adjust: RegisterNumericAdjustment { + scale: 0, + offset: 0, + }, + }, + }, + }; + + assert_eq!(reg.parse_words(&[843, 0]), json!(843)); +} diff --git a/modbus-mqtt/src/mqtt.rs b/modbus-mqtt/src/mqtt.rs new file mode 100644 index 0000000..857a45e --- /dev/null +++ b/modbus-mqtt/src/mqtt.rs @@ -0,0 +1,299 @@ +use std::collections::HashMap; + +use bytes::Bytes; +use rumqttc::{ + mqttbytes::matches as matches_topic, AsyncClient, Event, EventLoop, MqttOptions, Publish, + Subscribe, SubscribeFilter, +}; +use tokio::{ + select, + sync::mpsc::{self, channel, Receiver, Sender}, +}; +use tracing::{debug, info, warn}; + +#[derive(Debug)] +pub struct Payload { + pub bytes: Bytes, + pub topic: String, +} + +#[derive(Debug, Clone)] +pub enum Message { + Subscribe(Subscribe, Sender), + Publish(Publish), + Shutdown, +} + +pub(crate) async fn new(options: MqttOptions) -> Connection { + let (client, event_loop) = AsyncClient::new(options, 32); + + let (tx, rx) = channel(32); + Connection { + client, + event_loop, + subscriptions: HashMap::new(), + tx, + rx, + } +} + +// Maintain internal subscriptions as well as MQTT subscriptions. Relay all received messages on MQTT subscribed topics +// to internal components who have a matching topic. Unsubscribe topics when no one is listening anymore. +pub(crate) struct Connection { + subscriptions: HashMap>>, + tx: Sender, + rx: Receiver, + client: AsyncClient, + event_loop: EventLoop, +} + +impl Connection { + pub async fn run(&mut self) -> crate::Result<()> { + loop { + select! { + event = self.event_loop.poll() => { + self.handle_event(event?).await? + } + request = self.rx.recv() => { + match request { + None => return Ok(()), + Some(Message::Shutdown) => { + info!("MQTT connection shutting down"); + break; + } + Some(req) => self.handle_request(req).await?, + } + } + } + } + + Ok(()) + } + + pub fn handle(&self) -> Handle { + Handle { + prefix: None, + tx: self.tx.clone(), + } + } + + async fn handle_event(&mut self, event: Event) -> crate::Result<()> { + use rumqttc::Incoming; + + #[allow(clippy::single_match)] + match event { + Event::Incoming(Incoming::Publish(Publish { topic, payload, .. })) => { + debug!(%topic, ?payload, "publish"); + self.handle_data(topic, payload).await?; + } + // e => debug!(event = ?e), + _ => {} + } + + Ok(()) + } + + #[tracing::instrument(level = "debug", skip(self), fields(subscriptions = ?self.subscriptions.keys()))] + async fn handle_data(&mut self, topic: String, bytes: Bytes) -> crate::Result<()> { + let mut targets = vec![]; + + // Remove subscriptions whose channels are closed, adding matching channels to the `targets` vec. + self.subscriptions.retain(|filter, channels| { + if matches_topic(&topic, filter) { + channels.retain(|channel| { + if channel.is_closed() { + warn!(?channel, "closed"); + false + } else { + targets.push(channel.clone()); + true + } + }); + !channels.is_empty() + } else { + true + } + }); + + for target in targets { + if target + .send(Payload { + topic: topic.clone(), + bytes: bytes.clone(), + }) + .await + .is_err() + { + // These will be removed above next time a matching payload is removed + } + } + Ok(()) + } + + async fn handle_request(&mut self, request: Message) -> crate::Result<()> { + match request { + Message::Publish(Publish { + topic, + payload, + qos, + retain, + .. + }) => { + self.client + .publish_bytes(topic, qos, retain, payload) + .await? + } + Message::Subscribe(Subscribe { filters, .. }, channel) => { + for filter in &filters { + let channel = channel.clone(); + + // NOTE: Curently allows multiple components to watch the same topic filter, but if there is no need + // for this, it might make more sense to have it _replace_ the channel, so that old (stale) + // components automatically finish running. + match self.subscriptions.get_mut(&filter.path) { + Some(channels) => channels.push(channel), + None => { + self.subscriptions + .insert(filter.path.clone(), vec![channel]); + } + } + } + + self.client.subscribe_many(filters).await? + } + Message::Shutdown => panic!("Handled by the caller"), + } + Ok(()) + } +} + +#[derive(Clone)] +pub struct Handle { + prefix: Option, + tx: Sender, +} + +// IDEA: make subscribe+publish _generic_ over the payload type, as long as it implements a Payload trait we define, +// which allows them to perform the serialization/deserialization to Bytes. For most domain types, the trait would be +// implemented to use serde_json but for Bytes and Vec it would just return itself. +// The return values may need to be crate::Result> or crate::Result>>. +impl Handle { + pub async fn subscribe>(&self, topic: S) -> crate::Result> { + let (tx_bytes, rx) = mpsc::channel(8); + let mut msg = + Message::Subscribe(Subscribe::new(topic, rumqttc::QoS::AtLeastOnce), tx_bytes); + if let Some(prefix) = &self.prefix { + msg = msg.scoped(prefix.to_owned()); + } + self.tx + .send(msg) + .await + .map_err(|_| crate::Error::SendError)?; + Ok(rx) + } + pub async fn publish, B: Into>( + &self, + topic: S, + payload: B, + ) -> crate::Result<()> { + let mut msg = Message::Publish(Publish::new( + topic, + rumqttc::QoS::AtLeastOnce, + payload.into(), + )); + if let Some(prefix) = &self.prefix { + msg = msg.scoped(prefix.to_owned()); + } + self.tx + .send(msg) + .await + .map_err(|_| crate::Error::SendError)?; + Ok(()) + } + + pub async fn shutdown(self) -> crate::Result<()> { + self.tx + .send(Message::Shutdown) + .await + .map_err(|_| crate::Error::SendError) + } +} + +pub(crate) trait Scopable { + fn scoped>(&self, prefix: S) -> Self; +} + +// FIXME: this doesn't actually _prefix_ it _appends_ to the existing prefix, so there's probably a better name for this +// trait, like: Scopable +impl Scopable for Handle { + fn scoped>(&self, prefix: S) -> Self { + match self { + Self { prefix: None, tx } => Self { + prefix: Some(prefix.into()), + tx: tx.clone(), + }, + Self { + prefix: Some(existing), + tx, + } => Self { + prefix: Some(format!("{}/{}", existing, prefix.into())), + tx: tx.clone(), + }, + } + } +} + +impl Scopable for Message { + fn scoped>(&self, prefix: S) -> Self { + match self { + Message::Subscribe(sub, bytes) => Message::Subscribe(sub.scoped(prefix), bytes.clone()), + Message::Publish(publish) => Message::Publish(publish.scoped(prefix)), + other => (*other).clone(), + } + } +} + +impl Scopable for Subscribe { + fn scoped>(&self, prefix: S) -> Self { + let prefix: String = prefix.into(); + Self { + pkid: self.pkid, + filters: self + .filters + .iter() + .map(|f| f.clone().scoped(prefix.clone())) + .collect(), + } + } +} + +impl Scopable for Publish { + fn scoped>(&self, prefix: S) -> Self { + let mut prefixed = self.clone(); + prefixed.topic = format!("{}/{}", prefix.into(), &self.topic); + prefixed + } +} + +impl Scopable for SubscribeFilter { + fn scoped>(&self, prefix: S) -> Self { + SubscribeFilter { + path: format!("{}/{}", prefix.into(), &self.path), + qos: self.qos, + } + } +} + +impl From for Bytes { + fn from(payload: Payload) -> Self { + payload.bytes + } +} + +impl std::ops::Deref for Payload { + type Target = Bytes; + + fn deref(&self) -> &Self::Target { + &self.bytes + } +} diff --git a/modbus-mqtt/src/server.rs b/modbus-mqtt/src/server.rs new file mode 100644 index 0000000..24d0cbe --- /dev/null +++ b/modbus-mqtt/src/server.rs @@ -0,0 +1,57 @@ +use crate::{ + modbus, + mqtt::{self, Scopable}, +}; + +use rumqttc::MqttOptions; +use std::future::Future; +use tokio::sync::{broadcast, mpsc}; +use tracing::error; + +pub async fn run + Send>( + prefix: P, + mut mqtt_options: MqttOptions, + shutdown: impl Future, +) -> crate::Result<()> { + let prefix = prefix.into(); + + let (notify_shutdown, _) = broadcast::channel(1); + let (shutdown_complete_tx, mut shutdown_complete_rx) = mpsc::channel(1); + + mqtt_options.set_last_will(rumqttc::LastWill { + topic: prefix.clone(), + message: "offline".into(), + qos: rumqttc::QoS::AtMostOnce, + retain: false, + }); + let mut mqtt_connection = mqtt::new(mqtt_options).await; + let mqtt = mqtt_connection.handle(); + mqtt.publish(prefix.clone(), "online").await?; + + let mut connector = modbus::connector::new( + mqtt.scoped(prefix), + (notify_shutdown.subscribe(), shutdown_complete_tx.clone()).into(), + ); + + tokio::spawn(async move { + if let Err(err) = mqtt_connection.run().await { + error!(cause = %err, "MQTT connection error"); + } + }); + + tokio::spawn(async move { + if let Err(err) = connector.run().await { + error!(cause = %err, "Modbus connector error"); + } + }); + + shutdown.await; + drop(notify_shutdown); + drop(shutdown_complete_tx); + + // We want MQTT to be the last thing to shutdown, so it gets shutdown after everything else + shutdown_complete_rx.recv().await; + mqtt.shutdown().await?; + + Ok(()) +} diff --git a/modbus-mqtt/src/shutdown.rs b/modbus-mqtt/src/shutdown.rs new file mode 100644 index 0000000..640b221 --- /dev/null +++ b/modbus-mqtt/src/shutdown.rs @@ -0,0 +1,92 @@ +//! **Note**: this is a barely modified copy of the code which appears in mini-redis + +type Notify = tokio::sync::broadcast::Receiver<()>; +type Guard = tokio::sync::mpsc::Sender<()>; +/// Listens for the server shutdown signal. +/// +/// Shutdown is signalled using a `broadcast::Receiver`. Only a single value is +/// ever sent. Once a value has been sent via the broadcast channel, the server +/// should shutdown. +/// +/// The `Shutdown` struct listens for the signal and tracks that the signal has +/// been received. Callers may query for whether the shutdown signal has been +/// received or not. +/// +#[derive(Debug)] +pub(crate) struct Shutdown { + /// `true` if the shutdown signal has been received + shutdown: bool, + + /// The receive half of the channel used to listen for shutdown. + notify: Notify, + + /// Optional guard as a sender so that when the `Shutdown` struct is dropped, the other side of the channel is + /// closed. + guard: Option, +} + +impl Clone for Shutdown { + fn clone(&self) -> Self { + Self { + shutdown: self.shutdown, + notify: self.notify.resubscribe(), + guard: self.guard.clone(), + } + } +} + +impl Shutdown { + /// Create a new `Shutdown` backed by the given `broadcast::Receiver`. + pub(crate) fn new(notify: Notify) -> Shutdown { + Shutdown { + shutdown: false, + notify, + guard: None, + } + } + /// Create a new `Shutdown` backed by the given `broadcast::Receiver` with a given guard. + pub(crate) fn with_guard(notify: Notify, guard: Guard) -> Shutdown { + Shutdown { + shutdown: false, + notify, + guard: Some(guard), + } + } + + /// Returns `true` if the shutdown signal has been received. + pub(crate) fn is_shutdown(&self) -> bool { + self.shutdown + } + + /// Receive the shutdown notice, waiting if necessary. + pub(crate) async fn recv(&mut self) { + // If the shutdown signal has already been received, then return + // immediately. + if self.shutdown { + return; + } + + // Cannot receive a "lag error" as only one value is ever sent. + let _ = self.notify.recv().await; + + // Remember that the signal has been received. + self.shutdown = true; + } +} + +impl From for Shutdown { + fn from(notify: Notify) -> Self { + Self::new(notify) + } +} + +impl From<(Notify, Guard)> for Shutdown { + fn from((notify, guard): (Notify, Guard)) -> Self { + Self::with_guard(notify, guard) + } +} +impl From<(Guard, Notify)> for Shutdown { + fn from((guard, notify): (Guard, Notify)) -> Self { + Self::with_guard(notify, guard) + } +} diff --git a/sungrow-winets/src/lib.rs b/sungrow-winets/src/lib.rs index e277c6b..9748d01 100644 --- a/sungrow-winets/src/lib.rs +++ b/sungrow-winets/src/lib.rs @@ -344,6 +344,7 @@ struct Device { dev_type: u8, // unit/slave ID + #[allow(dead_code)] #[serde(deserialize_with = "serde_aux::prelude::deserialize_number_from_string")] phys_addr: u8, // UNUSED: @@ -397,7 +398,7 @@ fn test_deserialize_device() { enum WebSocketMessage { Connect { token: String }, - DeviceList { list: Vec }, + // DeviceList { list: Vec }, // Not yet used: // State, // system state @@ -414,7 +415,7 @@ enum WebSocketMessage { #[derive(Debug, Deserialize)] struct ResultList { - count: u16, + // count: u16, #[serde(rename = "list")] items: Vec, }