Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gradio_client; deploy.ssh; #44

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ Our method compiles the following models to generate the set of marks:

We are standing on the shoulder of the giant GPT-4V ([playground](https://chat.openai.com/))!

### :rocket: Docker Quick Start

# Build the image
sudo nvidia-docker build -t som .

# Run the image
sudo docker run -d -p 6092:6092 --gpus all --name som-container som

### :rocket: Quick Start

* Install segmentation packages
Expand Down Expand Up @@ -90,7 +98,7 @@ And you will see this interface:
To deploy SoM to EC2 on AWS via Github Actions:

1. Fork this repository and clone your fork to your local machine.
2. Follow the instructions at the top of `deploy.py`.
2. Follow the instructions at the top of [`deploy.py`](https://github.com/microsoft/SoM/blob/main/deploy.py).

### :point_right: Comparing standard GPT-4V and its combination with SoM Prompting
![teaser_github](https://github.com/microsoft/SoM/assets/11957155/e4720105-b4b2-40c0-9303-2d8f1cb27d91)
Expand Down
58 changes: 52 additions & 6 deletions deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@

python deploy.py status

8. (optional) SSH into the server:

python deploy.py ssh

Troubleshooting Token Scope Error:

If you encounter an error similar to the following when pushing changes to
Expand Down Expand Up @@ -124,6 +128,9 @@ class Config(BaseSettings):
AWS_EC2_INSTANCE_TYPE: str = "g4dn.xlarge" # (T4 16GB $0.526/hr x86_64)
AWS_EC2_USER: str = "ubuntu"

# Note: changing this requires changing the hard-coded value in other files
PORT = 6092

class Config:
env_file = ".env"
env_file_encoding = 'utf-8'
Expand Down Expand Up @@ -246,7 +253,7 @@ def create_key_pair(key_name: str = config.AWS_EC2_KEY_NAME, key_path: str = con
logger.error(f"Error creating key pair: {e}")
return None

def get_or_create_security_group_id(ports: list[int] = [22, 6092]) -> str | None:
def get_or_create_security_group_id(ports: list[int] = [22, config.PORT]) -> str | None:
"""
Retrieves or creates a security group with the specified ports opened.

Expand Down Expand Up @@ -445,7 +452,7 @@ def configure_ec2_instance(
ssh_retries = 0
while ssh_retries < max_ssh_retries:
try:
ssh_client.connect(hostname=ec2_instance_ip, username='ubuntu', pkey=key)
ssh_client.connect(hostname=ec2_instance_ip, username=config.AWS_EC2_USER, pkey=key)
break # Successful SSH connection, break out of the loop
except Exception as e:
ssh_retries += 1
Expand Down Expand Up @@ -556,7 +563,7 @@ def get_gradio_server_url(ip_address: str) -> str:
Returns:
str: The Gradio server URL
"""
url = f"http://{ip_address}:6092" # TODO: make port configurable
url = f"http://{ip_address}:{config.PORT}"
return url

def git_push_set_upstream(branch_name: str):
Expand Down Expand Up @@ -702,19 +709,58 @@ def stop(
@staticmethod
def status() -> None:
"""
Lists all EC2 instances tagged with the project name.
Lists all EC2 instances tagged with the project name, along with their HTTP URLs.

Returns:
None
"""
ec2 = boto3.resource('ec2')

instances = ec2.instances.filter(
Filters=[{'Name': 'tag:Name', 'Values': [config.PROJECT_NAME]}]
)

for instance in instances:
logger.info(f"Instance ID: {instance.id}, State: {instance.state['Name']}")
public_ip = instance.public_ip_address
if public_ip:
http_url = f"http://{public_ip}:{config.PORT}"
logger.info(f"Instance ID: {instance.id}, State: {instance.state['Name']}, HTTP URL: {http_url}")
else:
logger.info(f"Instance ID: {instance.id}, State: {instance.state['Name']}, HTTP URL: Not available (no public IP)")

@staticmethod
def ssh(project_name: str = config.PROJECT_NAME) -> None:
"""
Establishes an SSH connection to the EC2 instance associated with the specified project name using subprocess.

Args:
project_name (str): The project name used to tag the instance. Defaults to config.PROJECT_NAME.

Returns:
None
"""
ec2 = boto3.resource('ec2')
instances = ec2.instances.filter(
Filters=[
{'Name': 'tag:Name', 'Values': [project_name]},
{'Name': 'instance-state-name', 'Values': ['running']}
]
)

for instance in instances:
logger.info(f"Attempting to SSH into instance: ID - {instance.id}, IP - {instance.public_ip_address}")

# Build the SSH command
ssh_command = [
"ssh",
"-i", config.AWS_EC2_KEY_PATH,
f"{config.AWS_EC2_USER}@{instance.public_ip_address}"
]

# Start an interactive shell session
try:
subprocess.run(ssh_command, check=True)
except subprocess.CalledProcessError as e:
logger.error(f"SSH connection failed: {e}")

if __name__ == "__main__":
fire.Fire(Deploy)
1 change: 1 addition & 0 deletions deploy_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
boto3==1.34.18
fire==0.5.0
gitpython==3.1.41
gradio_client==0.17.0
jinja2==3.1.3
loguru==0.7.2
paramiko==3.4.0
Expand Down