Skip to content

Commit

Permalink
conformance test
Browse files Browse the repository at this point in the history
  • Loading branch information
garymm committed Oct 1, 2023
1 parent f548517 commit 39f0a66
Show file tree
Hide file tree
Showing 11 changed files with 365 additions and 8 deletions.
22 changes: 17 additions & 5 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ go_library(
importpath = "github.com/Astera-org/mlflow-go",
visibility = ["//visibility:public"],
deps = [
":protos_go",
":protos_go_pregen",
"@com_github_google_uuid//:uuid",
"@in_gopkg_yaml_v3//:yaml_v3",
],
Expand Down Expand Up @@ -50,14 +50,26 @@ proto_library(
name = "protos",
srcs = glob(["protos/**/*.proto"]),
strip_import_prefix = "protos",
deps = [
"@com_google_protobuf//:descriptor_proto",
],
deps = ["@com_google_protobuf//:descriptor_proto"],
)

# This target is used only in tools/update_protos.sh, which regenerates the
# checked-in .pb.go files.
# protos_go_pregen is used in the actual build, since it's faster.
# We need the pregenerated .pb.go files to support users of the `go` tool,
# so we may as well use them to speed up the build.
go_proto_library(
name = "protos_go",
importpath = "github.com/Astera-org/mlflow-go/protos",
protos = [":protos"],
visibility = ["__subpackages__"],
)

# Go library built from the pregenerated .pb.go files checked in under protos/.
# It shares its importpath with protos_go, so either target can satisfy the
# same Go import.
go_library(
name = "protos_go_pregen",
srcs = glob(["protos/*.pb.go"]),
importpath = "github.com/Astera-org/mlflow-go/protos",
deps = [
"@org_golang_google_protobuf//reflect/protoreflect",
"@org_golang_google_protobuf//runtime/protoimpl",
"@org_golang_google_protobuf//types/descriptorpb",
],
)
18 changes: 18 additions & 0 deletions MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,21 @@ use_repo(
"org_golang_google_protobuf",
"org_golang_x_tools",
)

# Python support (used by the conformance test under //conformance).
bazel_dep(name = "rules_python", version = "0.25.0")

python = use_extension("@rules_python//python/extensions:python.bzl", "python")

# Single source of truth for the Python version; shared by the toolchain and
# the pip hub below.
python_version = "3.11"

python.toolchain(python_version = python_version)

# The repo name encodes the version above ("3.11" -> python_3_11);
# tools/add_python_dep.sh runs this interpreter.
use_repo(python, "python_3_11")

# Third-party Python packages, resolved from the checked-in lock file and
# exposed as @pip//<package>:pkg.
pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip")
pip.parse(
hub_name = "pip",
python_version = python_version,
requirements_lock = "//:python_requirements_lock.txt",
)
use_repo(pip, "pip")
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ the unit tests, please run the manual tests.

The .proto files in the [protos](protos) directory are copied from the official mlflow repo.
Unfortunately not everybody uses Bazel, and so we have to check in the generated
protocol buffer code.
protocol buffer Go code.
To download the latest .proto files and regenerate the .pb.go files, run
[update_protos.sh](update_protos.sh).
[update_protos.sh](tools/update_protos.sh).

44 changes: 44 additions & 0 deletions conformance/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
load("@rules_python//python:defs.bzl", "py_binary", "py_test")
load("@rules_go//go:def.bzl", "go_binary")

# Go conformance binary. The "bazel" gotag satisfies the `//go:build bazel`
# constraint at the top of main.go, so the file is only compiled by Bazel.
go_binary(
name = "go",
srcs = ["main.go"],
gotags = ["bazel"],
deps = ["//:mlflow"],
)

# Python conformance binary: main.py run against the mlflow pip package.
py_binary(
name = "py",
srcs = ["main.py"],
main = "main.py",
deps = ["@pip//mlflow:pkg"],
)

# conformance_test starts an MLflow server, which runs gunicorn, which in turn
# imports mlflow. This target is a gunicorn executable that has access to
# mlflow; conformance_test.py adds its directory to PATH when no gunicorn is
# already available.
py_binary(
name = "gunicorn",
srcs = ["gunicorn_exe.py"],
main = "gunicorn_exe.py",
deps = [
"@pip//gunicorn:pkg",
"@pip//mlflow:pkg",
],
)

# Runs conformance_test.py. The Go and Python conformance binaries and the
# gunicorn wrapper are provided as data so the test can locate them through
# the runfiles library at run time.
py_test(
name = "conformance_test",
timeout = "short",
srcs = ["conformance_test.py"],
data = [
":go",
":gunicorn",
":py",
],
deps = [
"@pip//mlflow:pkg",
"@pip//requests:pkg",
"@rules_python//python/runfiles",
],
)
117 changes: 117 additions & 0 deletions conformance/conformance_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import logging
import multiprocessing
import os
import shutil
import socket
import subprocess
import tempfile
import time
import unittest

import mlflow
import mlflow.server
import python.runfiles.runfiles
import requests

runfiles = python.runfiles.runfiles.Create()

logging.basicConfig(level=logging.INFO)


class TC(unittest.TestCase):
    """Check that the Go and Python conformance binaries log equivalent data.

    Each binary is run once against a plain file store and once against an
    MLflow HTTP server; the resulting run is then read back with the Python
    MLflow client and compared against the expected metric, tag, param and
    artifact.
    """

    def start_mlflow_server(self, root_dir) -> str:
        """Start an MLflow server backed by ``root_dir`` and return its URI.

        The server runs in a child process. Its shutdown is registered as a
        test cleanup as soon as the process is started, so it is stopped even
        when the readiness checks below fail.

        Args:
            root_dir: Path handed to the server as its file store.

        Returns:
            The ``http://localhost:<port>`` base URI of the running server.

        Raises:
            RuntimeError: If gunicorn is not on PATH and the bundled gunicorn
                binary cannot be found in the runfiles.
        """
        # Look for a free port. The socket is closed before the server binds,
        # so the port is only very likely (not guaranteed) to still be free.
        with socket.socket() as sock:
            sock.bind(("", 0))
            port = sock.getsockname()[1]

        # mlflow server expects gunicorn on its PATH.
        if not shutil.which("gunicorn"):
            gunicorn_path = runfiles.Rlocation("_main/conformance/gunicorn")
            # NOTE(review): Rlocation is presumed to be able to return an
            # empty/None result for an unknown runfile; guard so that case
            # raises the intended RuntimeError rather than a TypeError.
            if not gunicorn_path or not os.path.exists(gunicorn_path):
                raise RuntimeError(f"Could not find gunicorn binary at {gunicorn_path}")
            # Prepend the wrapper's directory; tolerate an unset/empty PATH.
            path_entries = (os.path.dirname(gunicorn_path), os.environ.get("PATH", ""))
            os.environ["PATH"] = os.pathsep.join(entry for entry in path_entries if entry)
        logging.info(
            "Starting MLFlow server on port %d with file store %s",
            port,
            root_dir,
        )
        server_process = multiprocessing.Process(
            target=mlflow.server._run_server,
            kwargs={
                "file_store_path": root_dir,
                "registry_store_uri": None,
                "default_artifact_root": None,
                "serve_artifacts": False,
                "artifacts_only": False,
                "artifacts_destination": None,
                "host": "localhost",
                "port": port,
            },
        )
        server_process.start()
        # Cleanups run in LIFO order: terminate first, then reap the child
        # (bounded wait so a stuck server cannot hang the test).
        self.addCleanup(server_process.join, 5)
        self.addCleanup(server_process.terminate)

        # wait for server to start
        server_uri = f"http://localhost:{port}"
        up = False  # stays False if every attempt fails to connect
        for _ in range(10):
            try:
                up = requests.get(f"{server_uri}/api/2.0/mlflow/experiments/get?experiment_id=0")
                break
            except requests.exceptions.ConnectionError:
                time.sleep(0.5)
        self.assertTrue(up, f"server did not start: {server_uri}")
        self.assertTrue(server_process.is_alive())
        return server_uri

    def test_conformance(self):
        """Run each binary against both store kinds and verify the logged run."""
        for binary in (
            "_main/conformance/go_/go",
            "_main/conformance/py",
        ):
            lang = os.path.basename(binary)
            with self.subTest(lang=lang):
                for start_server in (False, True):
                    root_dir = tempfile.mkdtemp()
                    # Python mlflow client fails if the directory
                    # exists but does not already contain the default experiment.
                    # Remove it so that it creates the default experiment rather than failing.
                    os.rmdir(root_dir)
                    # The store is re-created at this path by the client or
                    # server; remove whatever is left when the test finishes.
                    self.addCleanup(shutil.rmtree, root_dir, ignore_errors=True)
                    with self.subTest(scheme="http" if start_server else "file"):
                        if start_server:
                            server_uri = self.start_mlflow_server(root_dir)
                            env = {"MLFLOW_TRACKING_URI": server_uri}
                        else:
                            env = {"MLFLOW_TRACKING_URI": root_dir}
                        # The binary sees only MLFLOW_TRACKING_URI in its environment.
                        subprocess.check_call(
                            (runfiles.Rlocation(binary),),
                            env=env,
                        )

                        client = mlflow.tracking.MlflowClient(
                            tracking_uri=env["MLFLOW_TRACKING_URI"]
                        )
                        exp = client.get_experiment_by_name("Default")
                        runs = client.search_runs([exp.experiment_id])
                        self.assertEqual(len(runs), 1)
                        run = runs[0]
                        metric_key = "metric0"
                        tag_key = "tag0"
                        param_key = "param0"

                        self.assertIn(metric_key, run.data.metrics)
                        # Both binaries log values 1..10; 10.0 is the last one logged.
                        self.assertEqual(run.data.metrics[metric_key], 10.0)
                        self.assertIn(tag_key, run.data.tags)
                        self.assertEqual(run.data.tags[tag_key], "value0")
                        self.assertIn(param_key, run.data.params)
                        self.assertEqual(run.data.params[param_key], "value0")

                        artifacts = client.list_artifacts(run.info.run_id)
                        self.assertEqual(len(artifacts), 1)
                        self.assertEqual(artifacts[0].path, "artifact0.txt")


if __name__ == "__main__":
unittest.main()
6 changes: 6 additions & 0 deletions conformance/gunicorn_exe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import sys

from gunicorn.app.wsgiapp import run

# Delegate to gunicorn's WSGI application entry point and use its return
# value as the process exit status.
if __name__ == "__main__":
sys.exit(run())
44 changes: 44 additions & 0 deletions conformance/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//go:build bazel

// ^^Regarding the above "go:build" directive:
// Go tool doesn't like mixing go and C++ in one directory,
// but this file is only really used by the conformance test
// which is built by bazel.

package main

import (
"log"
"os"
"path/filepath"

mlflow "github.com/Astera-org/mlflow-go"
)

// main logs the fixed set of data that conformance_test.py checks for:
// ten values of "metric0" (1..10 at steps 0..9), tag "tag0", param "param0",
// and a single artifact named "artifact0.txt". Any error aborts via panic so
// the test sees a non-zero exit status.
func main() {
	// NOTE(review): the run is presumably configured from environment
	// variables such as MLFLOW_TRACKING_URI, which the test sets — confirm
	// against ActiveRunFromEnv.
	run, err := mlflow.ActiveRunFromEnv("", log.Default())
	if err != nil {
		panic(err)
	}
	for i := int64(0); i < 10; i++ {
		run.LogMetric("metric0", float64(i+1), i)
	}
	run.SetTag("tag0", "value0")
	run.LogParam("param0", "value0")

	tempDir, err := os.MkdirTemp("", "*")
	if err != nil {
		panic(err)
	}
	// The directory is only a staging area for the artifact; remove it when
	// main returns (after run.End) instead of leaking it.
	defer os.RemoveAll(tempDir)

	artifactPath := filepath.Join(tempDir, "artifact0.txt")
	if err = os.WriteFile(artifactPath, []byte("hello\n"), 0644); err != nil {
		panic(err)
	}
	if err = run.LogArtifact(artifactPath, ""); err != nil {
		panic(err)
	}

	if err = run.End(); err != nil {
		panic(err)
	}
}
21 changes: 21 additions & 0 deletions conformance/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import os
import tempfile

import mlflow


def main():
    """Log the fixed data that conformance_test.py checks for.

    Logs ten values of ``metric0`` (1..10 at steps 0..9), tag ``tag0``,
    param ``param0`` and one artifact named ``artifact0.txt``.
    """
    for i in range(10):
        mlflow.log_metric("metric0", i + 1, step=i)

    mlflow.set_tag("tag0", "value0")
    mlflow.log_param("param0", "value0")
    # The temporary directory is only a staging area for the artifact, so it
    # is removed as soon as the artifact has been logged.
    with tempfile.TemporaryDirectory() as temp_dir:
        artifact_path = os.path.join(temp_dir, "artifact0.txt")
        with open(artifact_path, "wt") as f:
            f.write("hello\n")
        mlflow.log_artifact(artifact_path)


if __name__ == "__main__":
main()
57 changes: 57 additions & 0 deletions python_requirements_lock.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
alembic==1.12.0
blinker==1.6.2
certifi==2023.7.22
charset-normalizer==3.3.0
click==8.1.7
cloudpickle==2.2.1
contourpy==1.1.1
cycler==0.12.0
databricks-cli==0.17.8
docker==6.1.3
entrypoints==0.4
Flask==2.3.3
fonttools==4.43.0
gitdb==4.0.10
GitPython==3.1.37
gunicorn==21.2.0
idna==3.4
importlib-metadata==6.8.0
itsdangerous==2.1.2
Jinja2==3.1.2
joblib==1.3.2
kiwisolver==1.4.5
Mako==1.2.4
Markdown==3.4.4
MarkupSafe==2.1.3
matplotlib==3.8.0
mlflow==2.7.1
numpy==1.26.0
oauthlib==3.2.2
packaging==23.1
pandas==2.1.1
Pillow==10.0.1
pip==23.2.1
protobuf==4.24.3
pyarrow==13.0.0
PyJWT==2.8.0
pyparsing==3.1.1
python-dateutil==2.8.2
pytz==2023.3.post1
PyYAML==6.0.1
querystring-parser==1.2.4
requests==2.31.0
scikit-learn==1.3.1
scipy==1.11.3
setuptools==65.5.0
six==1.16.0
smmap==5.0.1
SQLAlchemy==2.0.21
sqlparse==0.4.4
tabulate==0.9.0
threadpoolctl==3.2.0
typing_extensions==4.8.0
tzdata==2023.3
urllib3==1.26.16
websocket-client==1.6.3
Werkzeug==3.0.0
zipp==3.17.0
32 changes: 32 additions & 0 deletions tools/add_python_dep.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/usr/bin/env bash
#
# Adds a Python package to python_requirements_lock.txt.
#
# Builds a throwaway virtualenv with the interpreter Bazel uses, installs the
# currently locked requirements plus the requested package, and re-freezes the
# result into the lock file.
#
# Usage: add_python_dep.sh <python package>

set -o errexit
set -o pipefail

if [[ "$#" -ne 1 ]]; then
  echo "Usage: $0 <python package>" >&2
  exit 1
fi

# Run from the repository root (the parent of this script's directory).
cd "$(dirname "$(dirname "$0")")"

VENV_DIR="$(mktemp -d)"
# Remove the throwaway venv on any exit, including failures.
trap 'rm -rf "${VENV_DIR}"' EXIT

# Use the Python interpreter that Bazel uses
bazel run @python_3_11//:python3 -- -m venv "${VENV_DIR}"

# Create a venv with existing dependencies
source "${VENV_DIR}/bin/activate"
pip install -U pip
pip install -r python_requirements_lock.txt

# add the new dependency
pip install "$@"

# freeze the new requirements; write to a scratch file first so a failed
# freeze cannot leave the lock file truncated
pip freeze --all > "${VENV_DIR}/python_requirements_lock.txt"
cp "${VENV_DIR}/python_requirements_lock.txt" python_requirements_lock.txt
# bazel run //:gazelle_python_manifest.update

deactivate
Loading

0 comments on commit 39f0a66

Please sign in to comment.