From 05a9295957cf961dd61b4c62506cf300caa62ed1 Mon Sep 17 00:00:00 2001 From: Michael Brooks Date: Tue, 21 Jan 2020 15:18:52 -0800 Subject: [PATCH] Use Coral Cloud IoT Core Library * Instead of implementing all the MQTT communication in this example, use the Coral library available in coral-cloud. * This also switches from a SW key (generated before first run) to using the on-board crypto chip. --- edge/README.md | 18 ++- edge/cloud_config.ini | 14 +++ edge/detect_cloudiot.py | 273 +++++++--------------------------------- 3 files changed, 71 insertions(+), 234 deletions(-) create mode 100644 edge/cloud_config.ini diff --git a/edge/README.md b/edge/README.md index e620592..aeeea85 100644 --- a/edge/README.md +++ b/edge/README.md @@ -9,8 +9,13 @@ Also make sure if you can [run a model with camera](https://coral.withgoogle.com 1. Install required libraries ``` -$ pip3 install pyjwt==1.7.1, paho-mqtt, imutils -$ sudo apt-get install python3-cryptography +echo "deb https://packages.cloud.google.com/apt coral-cloud-stable main" | sudo tee /etc/apt/sources.list.d/coral-cloud.list + +curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add - + +sudo apt update + +sudo apt install python3-coral-cloudiot ``` 2. Clone this repository @@ -31,8 +36,10 @@ Each device must be registered to IoT Core in order to publish data and receive 1. [Enable IoT Core and Pub/Sub](https://console.cloud.google.com/flows/enableapi?apiid=cloudiot.googleapis.com,pubsub) 2. [Create a device registry](https://cloud.google.com/iot/docs/quickstart#create_a_device_registry) Use `demo1` for Device name, and `demo-topic` for Topic name. -3. [Generate a device key pair](https://cloud.google.com/iot/docs/quickstart#generate_a_device_key_pair) -4. [Add a device to the registry](https://cloud.google.com/iot/docs/quickstart#add_a_device_to_the_registry) +3. [Add a device to the registry](https://cloud.google.com/iot/docs/quickstart#add_a_device_to_the_registry) +To get the ES256 key of the HW crypto run `python3 /usr/lib/python3/dist-packages/coral/cloudiot/ecc608_pubkey.py` +4. Configure the cloud config +Edit the cloud_config.ini in the edge folder to ProjectID, RegistryID, and DeviceID set in step 2. Enable Cloud IoT core by setting Enabled = true. 5. Install root certificate ``` $ wget https://pki.goog/roots.pem @@ -43,9 +50,8 @@ Run the following commands, then open browser. ``` export DEMO_FILES="$HOME/demo_files" -export PROJECTID="YOUR GCP PROJECT ID" python3 detect_cloudiot.py \ - --project_id ${PROJECTID} \ + --cloud_config cloud_config.ini \ --model ${DEMO_FILES}/mobilenet_ssd_v1_coco_quant_postprocess_edgetpu.tflite \ --labels ${DEMO_FILES}/coco_labels.txt \ --threshold 0.4 \ diff --git a/edge/cloud_config.ini b/edge/cloud_config.ini new file mode 100644 index 0000000..e938a2d --- /dev/null +++ b/edge/cloud_config.ini @@ -0,0 +1,14 @@ +[DEFAULT] +Enabled = false +ProjectID = my-project +CloudRegion = us-central1 +RegistryID = demo-registry +DeviceID = demo1 +# CA Certs should be pulled from https://pki.goog/roots.pem +CACerts = roots.pem +MQTTBridgeHostName = mqtt.googleapis.com +MQTTBridgePort = 8883 +# MessageType is expected to always be event. +MessageType = event +# RSA Cert is not required unless SW crypto is used. +RSACertFile = diff --git a/edge/detect_cloudiot.py b/edge/detect_cloudiot.py index 3477568..1939055 100644 --- a/edge/detect_cloudiot.py +++ b/edge/detect_cloudiot.py @@ -1,166 +1,17 @@ """A demo which runs object detection on camera frames. """ -# [START iot_mqtt_includes] -import datetime -import os -import ssl -import time -import jwt -import paho.mqtt.client as mqtt -# [END iot_mqtt_includes] import argparse import collections import colorsys import itertools import time +from coral.cloudiot.core import CloudIot from edgetpu.detection.engine import DetectionEngine from edgetpuvision import svg from edgetpuvision import utils from edgetpuvision.apps import run_app, run_server -# The initial backoff time after a disconnection occurs, in seconds. -minimum_backoff_time = 1 - -# The maximum backoff time before giving up, in seconds. -MAXIMUM_BACKOFF_TIME = 32 - -# Whether to wait with exponential backoff before publishing. -should_backoff = False - -# MQTT settings, change if needed -ALGORITHM = 'RS256' -MQTT_BRIDGE_HOSTNAME = 'mqtt.googleapis.com' -MQTT_BRIDGE_PORT = 8883 -PRIVATE_KEY_FILE = '/home/mendel/rsa_private.pem' -CA_CERTS = '/home/mendel/roots.pem' -JWT_EXPIRES_MINUTES = 20 -MESSAGE_TYPE = 'event' - -# Cloud IoT settings -CLOUD_REGION = 'us-central1' -DEVICE_ID = 'demo1' -REGISTRY_ID = 'demo-registry' - - -# [START iot_mqtt_jwt] -def create_jwt(project_id, private_key_file, algorithm): - """Creates a JWT (https://jwt.io) to establish an MQTT connection. - Args: - project_id: The cloud project ID this device belongs to - private_key_file: A path to a file containing either an RSA256 or - ES256 private key. - algorithm: The encryption algorithm to use. Either 'RS256' or 'ES256' - Returns: - An MQTT generated from the given project_id and private key, which - expires in 20 minutes. After 20 minutes, your client will be - disconnected, and a new JWT will have to be generated. - Raises: - ValueError: If the private_key_file does not contain a known key. - """ - - token = { - # The time that the token was issued at - 'iat': datetime.datetime.utcnow(), - # The time the token expires. - 'exp': datetime.datetime.utcnow() + datetime.timedelta(minutes=60), - # The audience field should always be set to the GCP project id. - 'aud': project_id - } - - # Read the private key file. - with open(private_key_file, 'r') as f: - private_key = f.read() - - print('Creating JWT using {} from private key file {}'.format( - algorithm, private_key_file)) - - return jwt.encode(token, private_key, algorithm=algorithm) -# [END iot_mqtt_jwt] - - -# [START iot_mqtt_config] -def error_str(rc): - """Convert a Paho error to a human readable string.""" - return '{}: {}'.format(rc, mqtt.error_string(rc)) - - -def on_connect(unused_client, unused_userdata, unused_flags, rc): - """Callback for when a device connects.""" - print('on_connect', mqtt.connack_string(rc)) - - # After a successful connect, reset backoff time and stop backing off. - global should_backoff - global minimum_backoff_time - should_backoff = False - minimum_backoff_time = 1 - - -def on_disconnect(unused_client, unused_userdata, rc): - """Paho callback for when a device disconnects.""" - print('on_disconnect', error_str(rc)) - - # Since a disconnect occurred, the next loop iteration will wait with - # exponential backoff. - global should_backoff - should_backoff = True - - -def on_publish(unused_client, unused_userdata, unused_mid): - """Paho callback when a message is sent to the broker.""" - print('on_publish') - - -def on_message(unused_client, unused_userdata, message): - """Callback when the device receives a message on a subscription.""" - payload = str(message.payload) - print('Received message \'{}\' on topic \'{}\' with Qos {}'.format( - payload, message.topic, str(message.qos))) - - -def get_client( - project_id, cloud_region, registry_id, device_id, private_key_file, - algorithm, ca_certs, mqtt_bridge_hostname, mqtt_bridge_port): - """Create our MQTT client. The client_id is a unique string that identifies - this device. For Google Cloud IoT Core, it must be in the format below.""" - client = mqtt.Client( - client_id=('projects/{}/locations/{}/registries/{}/devices/{}' - .format( - project_id, - cloud_region, - registry_id, - device_id))) - - # With Google Cloud IoT Core, the username field is ignored, and the - # password field is used to transmit a JWT to authorize the device. - client.username_pw_set(username='unused', - password=create_jwt( - project_id, private_key_file, algorithm)) - - # Enable SSL/TLS support. - client.tls_set(ca_certs=ca_certs, tls_version=ssl.PROTOCOL_TLSv1_2) - - # Register message callbacks. https://eclipse.org/paho/clients/python/docs/ - # describes additional callbacks that Paho supports. In this example, the - # callbacks just print to standard out. - client.on_connect = on_connect - client.on_publish = on_publish - client.on_disconnect = on_disconnect - client.on_message = on_message - - # Connect to the Google MQTT bridge. - client.connect(mqtt_bridge_hostname, mqtt_bridge_port) - - # This is the topic that the device will receive configuration updates on. - mqtt_config_topic = '/devices/{}/config'.format(device_id) - - # Subscribe to the config topic. - client.subscribe(mqtt_config_topic, qos=1) - - return client -# [END iot_mqtt_config] - - CSS_STYLES = str(svg.CssStyle({'.back': svg.Style(fill='black', stroke='black', stroke_width='0.5em'), @@ -266,17 +117,6 @@ def render_gen(args): global minimum_backoff_time import json import random - # Publish to the events or state topic based on the flag. - sub_topic = 'events' if MESSAGE_TYPE == 'event' else 'state' - - mqtt_topic = '/devices/{}/{}'.format(DEVICE_ID, sub_topic) - - jwt_iat = datetime.datetime.utcnow() - jwt_exp_mins = JWT_EXPIRES_MINUTES - client = get_client( - args.project_id, CLOUD_REGION, REGISTRY_ID, DEVICE_ID, - PRIVATE_KEY_FILE, ALGORITHM, CA_CERTS, - MQTT_BRIDGE_HOSTNAME, MQTT_BRIDGE_PORT) fps_counter = utils.avg_fps_counter(30) engines, titles = utils.make_engines(args.model, DetectionEngine) @@ -289,79 +129,54 @@ def render_gen(args): draw_overlay = True yield utils.input_image_size(engine) output = None - mqtt_cnt = 0 + message_count = 0 inference_time_window = collections.deque(maxlen=30) inference_time = 0.0 - while True: - d = {} - tensor, layout, command = (yield output) - inference_rate = next(fps_counter) - if draw_overlay: - start = time.monotonic() - objs = engine .DetectWithInputTensor(tensor, - threshold=args.threshold, - top_k=args.top_k) - inference_time_ = time.monotonic() - start - inference_time_window.append(inference_time_) - inference_time = sum(inference_time_window) / len(inference_time_window) - objs = [convert(obj, labels) for obj in objs] - if labels and filtered_labels: - objs = [obj for obj in objs if obj.label in filtered_labels] - objs = [obj for obj in objs if args.min_area <= obj.bbox.area() <= args.max_area] - if args.print: - print_results(inference_rate, objs) - for ind, obj in enumerate(objs): - tx = obj.label - o = {"name": tx, "points": ",".join([str(i) for i in obj.bbox_flat])} - d[ind] = o - - title = titles[engine] - output = overlay(title, objs, get_color, inference_time, inference_rate, layout) - else: - output = None - if command == 'o': - draw_overlay = not draw_overlay - elif command == 'n': - engine = next(engines) - - # Wait if backoff is required. - if should_backoff: - # If backoff time is too large, give up. - if minimum_backoff_time > MAXIMUM_BACKOFF_TIME: - print('Exceeded maximum backoff time. Giving up.') - break - - # Otherwise, wait and connect again. - delay = minimum_backoff_time + random.randint(0, 1000) / 1000.0 - print('Waiting for {} before reconnecting.'.format(delay)) - time.sleep(delay) - minimum_backoff_time *= 2 - client.connect(MQTT_BRIDGE_HOSTNAME, MQTT_BRIDGE_PORT) - payload = json.dumps(d) - - # [START iot_mqtt_jwt_refresh] - seconds_since_issue = (datetime.datetime.utcnow() - jwt_iat).seconds - if seconds_since_issue > 60 * jwt_exp_mins: - print('Refreshing token after {}s'.format(seconds_since_issue)) - jwt_iat = datetime.datetime.utcnow() - client = get_client( - PROJECT_ID, CLOUD_REGION, - REGISTRY_ID, DEVICE_ID, PRIVATE_KEY_FILE, - ALGORITHM, CA_CERTS, MQTT_BRIDGE_HOSTNAME, - MQTT_BRIDGE_PORT) - if mqtt_cnt > 0 and mqtt_cnt % 10 == 0: - client.loop() - print("-" * 20) - print(d) - print("-" * 20) - client.publish(mqtt_topic, payload, qos=1) - mqtt_cnt += 1 + with CloudIot(args.cloud_config) as cloud: + while True: + d = {} + tensor, layout, command = (yield output) + inference_rate = next(fps_counter) + if draw_overlay: + start = time.monotonic() + objs = engine.detect_with_input_tensor(tensor, + threshold=args.threshold, + top_k=args.top_k) + inference_time_ = time.monotonic() - start + inference_time_window.append(inference_time_) + inference_time = sum(inference_time_window) / len(inference_time_window) + objs = [convert(obj, labels) for obj in objs] + if labels and filtered_labels: + objs = [obj for obj in objs if obj.label in filtered_labels] + objs = [obj for obj in objs if args.min_area <= obj.bbox.area() <= args.max_area] + if args.print: + print_results(inference_rate, objs) + for ind, obj in enumerate(objs): + tx = obj.label + o = {"name": tx, "points": ",".join([str(i) for i in obj.bbox_flat])} + d[ind] = o + + title = titles[engine] + output = overlay(title, objs, get_color, inference_time, inference_rate, layout) + else: + output = None + if command == 'o': + draw_overlay = not draw_overlay + elif command == 'n': + engine = next(engines) + + payload = json.dumps(d) + + if message_count > 0 and message_count % 10 == 0: + print("-" * 20) + print(d) + print("-" * 20) + cloud.publish_message(payload) + message_count += 1 def add_render_gen_args(parser): - parser.add_argument('--project_id', - help='GCP Project ID', required=True) parser.add_argument('--model', help='.tflite model path', required=True) parser.add_argument('--labels', @@ -380,6 +195,8 @@ def add_render_gen_args(parser): help='Bounding box display color'), parser.add_argument('--print', default=False, action='store_true', help='Print inference results') + parser.add_argument('--cloud_config', + help='Cloud Config path', required=True) def main():