Skip to content

Commit

Permalink
Merge pull request #1 from mbrooksx/master
Browse files Browse the repository at this point in the history
Use Coral Cloud IoT Core Library
  • Loading branch information
hayatoy authored Jan 30, 2020
2 parents c0c42fb + 05a9295 commit fb721e6
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 234 deletions.
18 changes: 12 additions & 6 deletions edge/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,13 @@ Also make sure if you can [run a model with camera](https://coral.withgoogle.com

1. Install required libraries
```
$ pip3 install pyjwt==1.7.1, paho-mqtt, imutils
$ sudo apt-get install python3-cryptography
echo "deb https://packages.cloud.google.com/apt coral-cloud-stable main" | sudo tee /etc/apt/sources.list.d/coral-cloud.list
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
sudo apt update
sudo apt install python3-coral-cloudiot
```

2. Clone this repository
Expand All @@ -31,8 +36,10 @@ Each device must be registered to IoT Core in order to publish data and receive
1. [Enable IoT Core and Pub/Sub](https://console.cloud.google.com/flows/enableapi?apiid=cloudiot.googleapis.com,pubsub)
2. [Create a device registry](https://cloud.google.com/iot/docs/quickstart#create_a_device_registry)
Use `demo1` for Device name, and `demo-topic` for Topic name.
3. [Generate a device key pair](https://cloud.google.com/iot/docs/quickstart#generate_a_device_key_pair)
4. [Add a device to the registry](https://cloud.google.com/iot/docs/quickstart#add_a_device_to_the_registry)
3. [Add a device to the registry](https://cloud.google.com/iot/docs/quickstart#add_a_device_to_the_registry)
To get the ES256 key of the HW crypto run `python3 /usr/lib/python3/dist-packages/coral/cloudiot/ecc608_pubkey.py`
4. Configure the cloud config
Edit the cloud_config.ini in the edge folder to ProjectID, RegistryID, and DeviceID set in step 2. Enable Cloud IoT core by setting Enabled = true.
5. Install root certificate
```
$ wget https://pki.goog/roots.pem
Expand All @@ -43,9 +50,8 @@ Run the following commands, then open browser.

```
export DEMO_FILES="$HOME/demo_files"
export PROJECTID="YOUR GCP PROJECT ID"
python3 detect_cloudiot.py \
--project_id ${PROJECTID} \
--cloud_config cloud_config.ini \
--model ${DEMO_FILES}/mobilenet_ssd_v1_coco_quant_postprocess_edgetpu.tflite \
--labels ${DEMO_FILES}/coco_labels.txt \
--threshold 0.4 \
Expand Down
14 changes: 14 additions & 0 deletions edge/cloud_config.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
[DEFAULT]
Enabled = false
ProjectID = my-project
CloudRegion = us-central1
RegistryID = demo-registry
DeviceID = demo1
# CA Certs should be pulled from https://pki.goog/roots.pem
CACerts = roots.pem
MQTTBridgeHostName = mqtt.googleapis.com
MQTTBridgePort = 8883
# MessageType is expected to always be event.
MessageType = event
# RSA Cert is not required unless SW crypto is used.
RSACertFile =
273 changes: 45 additions & 228 deletions edge/detect_cloudiot.py
Original file line number Diff line number Diff line change
@@ -1,166 +1,17 @@
"""A demo which runs object detection on camera frames.
"""
# [START iot_mqtt_includes]
import datetime
import os
import ssl
import time
import jwt
import paho.mqtt.client as mqtt
# [END iot_mqtt_includes]

import argparse
import collections
import colorsys
import itertools
import time
from coral.cloudiot.core import CloudIot
from edgetpu.detection.engine import DetectionEngine
from edgetpuvision import svg
from edgetpuvision import utils
from edgetpuvision.apps import run_app, run_server

# The initial backoff time after a disconnection occurs, in seconds.
minimum_backoff_time = 1

# The maximum backoff time before giving up, in seconds.
MAXIMUM_BACKOFF_TIME = 32

# Whether to wait with exponential backoff before publishing.
should_backoff = False

# MQTT settings, change if needed
ALGORITHM = 'RS256'
MQTT_BRIDGE_HOSTNAME = 'mqtt.googleapis.com'
MQTT_BRIDGE_PORT = 8883
PRIVATE_KEY_FILE = '/home/mendel/rsa_private.pem'
CA_CERTS = '/home/mendel/roots.pem'
JWT_EXPIRES_MINUTES = 20
MESSAGE_TYPE = 'event'

# Cloud IoT settings
CLOUD_REGION = 'us-central1'
DEVICE_ID = 'demo1'
REGISTRY_ID = 'demo-registry'


# [START iot_mqtt_jwt]
def create_jwt(project_id, private_key_file, algorithm):
"""Creates a JWT (https://jwt.io) to establish an MQTT connection.
Args:
project_id: The cloud project ID this device belongs to
private_key_file: A path to a file containing either an RSA256 or
ES256 private key.
algorithm: The encryption algorithm to use. Either 'RS256' or 'ES256'
Returns:
An MQTT generated from the given project_id and private key, which
expires in 20 minutes. After 20 minutes, your client will be
disconnected, and a new JWT will have to be generated.
Raises:
ValueError: If the private_key_file does not contain a known key.
"""

token = {
# The time that the token was issued at
'iat': datetime.datetime.utcnow(),
# The time the token expires.
'exp': datetime.datetime.utcnow() + datetime.timedelta(minutes=60),
# The audience field should always be set to the GCP project id.
'aud': project_id
}

# Read the private key file.
with open(private_key_file, 'r') as f:
private_key = f.read()

print('Creating JWT using {} from private key file {}'.format(
algorithm, private_key_file))

return jwt.encode(token, private_key, algorithm=algorithm)
# [END iot_mqtt_jwt]


# [START iot_mqtt_config]
def error_str(rc):
"""Convert a Paho error to a human readable string."""
return '{}: {}'.format(rc, mqtt.error_string(rc))


def on_connect(unused_client, unused_userdata, unused_flags, rc):
"""Callback for when a device connects."""
print('on_connect', mqtt.connack_string(rc))

# After a successful connect, reset backoff time and stop backing off.
global should_backoff
global minimum_backoff_time
should_backoff = False
minimum_backoff_time = 1


def on_disconnect(unused_client, unused_userdata, rc):
"""Paho callback for when a device disconnects."""
print('on_disconnect', error_str(rc))

# Since a disconnect occurred, the next loop iteration will wait with
# exponential backoff.
global should_backoff
should_backoff = True


def on_publish(unused_client, unused_userdata, unused_mid):
"""Paho callback when a message is sent to the broker."""
print('on_publish')


def on_message(unused_client, unused_userdata, message):
"""Callback when the device receives a message on a subscription."""
payload = str(message.payload)
print('Received message \'{}\' on topic \'{}\' with Qos {}'.format(
payload, message.topic, str(message.qos)))


def get_client(
project_id, cloud_region, registry_id, device_id, private_key_file,
algorithm, ca_certs, mqtt_bridge_hostname, mqtt_bridge_port):
"""Create our MQTT client. The client_id is a unique string that identifies
this device. For Google Cloud IoT Core, it must be in the format below."""
client = mqtt.Client(
client_id=('projects/{}/locations/{}/registries/{}/devices/{}'
.format(
project_id,
cloud_region,
registry_id,
device_id)))

# With Google Cloud IoT Core, the username field is ignored, and the
# password field is used to transmit a JWT to authorize the device.
client.username_pw_set(username='unused',
password=create_jwt(
project_id, private_key_file, algorithm))

# Enable SSL/TLS support.
client.tls_set(ca_certs=ca_certs, tls_version=ssl.PROTOCOL_TLSv1_2)

# Register message callbacks. https://eclipse.org/paho/clients/python/docs/
# describes additional callbacks that Paho supports. In this example, the
# callbacks just print to standard out.
client.on_connect = on_connect
client.on_publish = on_publish
client.on_disconnect = on_disconnect
client.on_message = on_message

# Connect to the Google MQTT bridge.
client.connect(mqtt_bridge_hostname, mqtt_bridge_port)

# This is the topic that the device will receive configuration updates on.
mqtt_config_topic = '/devices/{}/config'.format(device_id)

# Subscribe to the config topic.
client.subscribe(mqtt_config_topic, qos=1)

return client
# [END iot_mqtt_config]


CSS_STYLES = str(svg.CssStyle({'.back': svg.Style(fill='black',
stroke='black',
stroke_width='0.5em'),
Expand Down Expand Up @@ -266,17 +117,6 @@ def render_gen(args):
global minimum_backoff_time
import json
import random
# Publish to the events or state topic based on the flag.
sub_topic = 'events' if MESSAGE_TYPE == 'event' else 'state'

mqtt_topic = '/devices/{}/{}'.format(DEVICE_ID, sub_topic)

jwt_iat = datetime.datetime.utcnow()
jwt_exp_mins = JWT_EXPIRES_MINUTES
client = get_client(
args.project_id, CLOUD_REGION, REGISTRY_ID, DEVICE_ID,
PRIVATE_KEY_FILE, ALGORITHM, CA_CERTS,
MQTT_BRIDGE_HOSTNAME, MQTT_BRIDGE_PORT)

fps_counter = utils.avg_fps_counter(30)
engines, titles = utils.make_engines(args.model, DetectionEngine)
Expand All @@ -289,79 +129,54 @@ def render_gen(args):
draw_overlay = True
yield utils.input_image_size(engine)
output = None
mqtt_cnt = 0
message_count = 0
inference_time_window = collections.deque(maxlen=30)
inference_time = 0.0

while True:
d = {}
tensor, layout, command = (yield output)
inference_rate = next(fps_counter)
if draw_overlay:
start = time.monotonic()
objs = engine .DetectWithInputTensor(tensor,
threshold=args.threshold,
top_k=args.top_k)
inference_time_ = time.monotonic() - start
inference_time_window.append(inference_time_)
inference_time = sum(inference_time_window) / len(inference_time_window)
objs = [convert(obj, labels) for obj in objs]
if labels and filtered_labels:
objs = [obj for obj in objs if obj.label in filtered_labels]
objs = [obj for obj in objs if args.min_area <= obj.bbox.area() <= args.max_area]
if args.print:
print_results(inference_rate, objs)
for ind, obj in enumerate(objs):
tx = obj.label
o = {"name": tx, "points": ",".join([str(i) for i in obj.bbox_flat])}
d[ind] = o

title = titles[engine]
output = overlay(title, objs, get_color, inference_time, inference_rate, layout)
else:
output = None
if command == 'o':
draw_overlay = not draw_overlay
elif command == 'n':
engine = next(engines)

# Wait if backoff is required.
if should_backoff:
# If backoff time is too large, give up.
if minimum_backoff_time > MAXIMUM_BACKOFF_TIME:
print('Exceeded maximum backoff time. Giving up.')
break

# Otherwise, wait and connect again.
delay = minimum_backoff_time + random.randint(0, 1000) / 1000.0
print('Waiting for {} before reconnecting.'.format(delay))
time.sleep(delay)
minimum_backoff_time *= 2
client.connect(MQTT_BRIDGE_HOSTNAME, MQTT_BRIDGE_PORT)
payload = json.dumps(d)

# [START iot_mqtt_jwt_refresh]
seconds_since_issue = (datetime.datetime.utcnow() - jwt_iat).seconds
if seconds_since_issue > 60 * jwt_exp_mins:
print('Refreshing token after {}s'.format(seconds_since_issue))
jwt_iat = datetime.datetime.utcnow()
client = get_client(
PROJECT_ID, CLOUD_REGION,
REGISTRY_ID, DEVICE_ID, PRIVATE_KEY_FILE,
ALGORITHM, CA_CERTS, MQTT_BRIDGE_HOSTNAME,
MQTT_BRIDGE_PORT)
if mqtt_cnt > 0 and mqtt_cnt % 10 == 0:
client.loop()
print("-" * 20)
print(d)
print("-" * 20)
client.publish(mqtt_topic, payload, qos=1)
mqtt_cnt += 1
with CloudIot(args.cloud_config) as cloud:
while True:
d = {}
tensor, layout, command = (yield output)
inference_rate = next(fps_counter)
if draw_overlay:
start = time.monotonic()
objs = engine.detect_with_input_tensor(tensor,
threshold=args.threshold,
top_k=args.top_k)
inference_time_ = time.monotonic() - start
inference_time_window.append(inference_time_)
inference_time = sum(inference_time_window) / len(inference_time_window)
objs = [convert(obj, labels) for obj in objs]
if labels and filtered_labels:
objs = [obj for obj in objs if obj.label in filtered_labels]
objs = [obj for obj in objs if args.min_area <= obj.bbox.area() <= args.max_area]
if args.print:
print_results(inference_rate, objs)
for ind, obj in enumerate(objs):
tx = obj.label
o = {"name": tx, "points": ",".join([str(i) for i in obj.bbox_flat])}
d[ind] = o

title = titles[engine]
output = overlay(title, objs, get_color, inference_time, inference_rate, layout)
else:
output = None
if command == 'o':
draw_overlay = not draw_overlay
elif command == 'n':
engine = next(engines)

payload = json.dumps(d)

if message_count > 0 and message_count % 10 == 0:
print("-" * 20)
print(d)
print("-" * 20)
cloud.publish_message(payload)
message_count += 1


def add_render_gen_args(parser):
parser.add_argument('--project_id',
help='GCP Project ID', required=True)
parser.add_argument('--model',
help='.tflite model path', required=True)
parser.add_argument('--labels',
Expand All @@ -380,6 +195,8 @@ def add_render_gen_args(parser):
help='Bounding box display color'),
parser.add_argument('--print', default=False, action='store_true',
help='Print inference results')
parser.add_argument('--cloud_config',
help='Cloud Config path', required=True)


def main():
Expand Down

0 comments on commit fb721e6

Please sign in to comment.