-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
365 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
load("@rules_python//python:defs.bzl", "py_binary", "py_test") | ||
load("@rules_go//go:def.bzl", "go_binary") | ||
|
||
go_binary( | ||
name = "go", | ||
srcs = ["main.go"], | ||
gotags = ["bazel"], | ||
deps = ["//:mlflow"], | ||
) | ||
|
||
py_binary( | ||
name = "py", | ||
srcs = ["main.py"], | ||
main = "main.py", | ||
deps = ["@pip//mlflow:pkg"], | ||
) | ||
|
||
# conformance_test starts MLFlow server, which runs gunicorn, which imports | ||
# mlflow. This is a gunicorn executable that has access to mlflow. | ||
py_binary( | ||
name = "gunicorn", | ||
srcs = ["gunicorn_exe.py"], | ||
main = "gunicorn_exe.py", | ||
deps = [ | ||
"@pip//gunicorn:pkg", | ||
"@pip//mlflow:pkg", | ||
], | ||
) | ||
|
||
py_test( | ||
name = "conformance_test", | ||
timeout = "short", | ||
srcs = ["conformance_test.py"], | ||
data = [ | ||
":go", | ||
":gunicorn", | ||
":py", | ||
], | ||
deps = [ | ||
"@pip//mlflow:pkg", | ||
"@pip//requests:pkg", | ||
"@rules_python//python/runfiles", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import logging | ||
import multiprocessing | ||
import os | ||
import shutil | ||
import socket | ||
import subprocess | ||
import tempfile | ||
import time | ||
import unittest | ||
|
||
import mlflow | ||
import mlflow.server | ||
import python.runfiles.runfiles | ||
import requests | ||
|
||
runfiles = python.runfiles.runfiles.Create() | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
|
||
|
||
class TC(unittest.TestCase): | ||
def start_mlflow_server(self, root_dir) -> str: | ||
# look for a free port | ||
sock = socket.socket() | ||
sock.bind(("", 0)) | ||
port = sock.getsockname()[1] | ||
sock.close() | ||
|
||
# mlflow server expects gunicorn on its PATH. | ||
if not shutil.which("gunicorn"): | ||
gunicorn_path = runfiles.Rlocation("_main/conformance/gunicorn") | ||
if not os.path.exists(gunicorn_path): | ||
raise RuntimeError(f"Could not find gunicorn binary at {gunicorn_path}") | ||
os.environ["PATH"] = os.path.dirname(gunicorn_path) + ":" + os.environ["PATH"] | ||
logging.info( | ||
"Starting MLFlow server on port %d with file store %s", | ||
port, | ||
root_dir, | ||
) | ||
server_process = multiprocessing.Process( | ||
target=mlflow.server._run_server, | ||
kwargs={ | ||
"file_store_path": root_dir, | ||
"registry_store_uri": None, | ||
"default_artifact_root": None, | ||
"serve_artifacts": False, | ||
"artifacts_only": False, | ||
"artifacts_destination": None, | ||
"host": "localhost", | ||
"port": port, | ||
}, | ||
) | ||
server_process.start() | ||
# wait for server to start | ||
server_uri = f"http://localhost:{port}" | ||
up = False | ||
for _ in range(10): | ||
try: | ||
up = requests.get(f"{server_uri}/api/2.0/mlflow/experiments/get?experiment_id=0") | ||
break | ||
except requests.exceptions.ConnectionError: | ||
time.sleep(0.5) | ||
self.assertTrue(up, f"server did not start: {server_uri}") | ||
self.assertTrue(server_process.is_alive()) | ||
self.addCleanup(server_process.terminate) | ||
return server_uri | ||
|
||
def test_conformance(self): | ||
for binary in ( | ||
"_main/conformance/go_/go", | ||
"_main/conformance/py", | ||
): | ||
lang = os.path.basename(binary) | ||
with self.subTest(lang=lang): | ||
for start_server in (False, True): | ||
root_dir = tempfile.mkdtemp() | ||
# Python mlflow client fails if the directory | ||
# exists but does not already contain the default experiment. | ||
# Remove it so that it creates the default experiment rather than failing. | ||
os.rmdir(root_dir) | ||
with self.subTest(scheme="http" if start_server else "file"): | ||
server_process = None | ||
if start_server: | ||
server_uri = self.start_mlflow_server(root_dir) | ||
env = {"MLFLOW_TRACKING_URI": server_uri} | ||
else: | ||
env = {"MLFLOW_TRACKING_URI": root_dir} | ||
subprocess.check_call( | ||
(runfiles.Rlocation(binary),), | ||
env=env, | ||
) | ||
|
||
client = mlflow.tracking.MlflowClient( | ||
tracking_uri=env["MLFLOW_TRACKING_URI"] | ||
) | ||
exp = client.get_experiment_by_name("Default") | ||
runs = client.search_runs([exp.experiment_id]) | ||
self.assertEqual(len(runs), 1) | ||
run = runs[0] | ||
metric_key = "metric0" | ||
tag_key = "tag0" | ||
param_key = "param0" | ||
|
||
self.assertIn(metric_key, run.data.metrics) | ||
self.assertEqual(run.data.metrics[metric_key], 10.0) | ||
self.assertIn(tag_key, run.data.tags) | ||
self.assertEqual(run.data.tags[tag_key], "value0") | ||
self.assertIn(param_key, run.data.params) | ||
self.assertEqual(run.data.params[param_key], "value0") | ||
|
||
artifacts = client.list_artifacts(run.info.run_id) | ||
self.assertEqual(len(artifacts), 1) | ||
self.assertEqual(artifacts[0].path, "artifact0.txt") | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
import sys | ||
|
||
from gunicorn.app.wsgiapp import run | ||
|
||
if __name__ == "__main__": | ||
sys.exit(run()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
//go:build bazel | ||
|
||
// ^^Regarding the above "go:build" directive: | ||
// Go tool doesn't like mixing go and C++ in one directory, | ||
// but this file is only really used by the conformance test | ||
// which is built by bazel. | ||
|
||
package main | ||
|
||
import ( | ||
"log" | ||
"os" | ||
"path/filepath" | ||
|
||
mlflow "github.com/Astera-org/mlflow-go" | ||
) | ||
|
||
func main() { | ||
run, err := mlflow.ActiveRunFromEnv("", log.Default()) | ||
if err != nil { | ||
panic(err) | ||
} | ||
for i := int64(0); i < 10; i++ { | ||
run.LogMetric("metric0", float64(i+1), i) | ||
} | ||
run.SetTag("tag0", "value0") | ||
run.LogParam("param0", "value0") | ||
|
||
tempDir, err := os.MkdirTemp("", "*") | ||
if err != nil { | ||
panic(err) | ||
} | ||
artifactPath := filepath.Join(tempDir, "artifact0.txt") | ||
if err = os.WriteFile(artifactPath, []byte("hello\n"), 0644); err != nil { | ||
panic(err) | ||
} | ||
if err = run.LogArtifact(artifactPath, ""); err != nil { | ||
panic(err) | ||
} | ||
|
||
if err = run.End(); err != nil { | ||
panic(err) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import os | ||
import tempfile | ||
|
||
import mlflow | ||
|
||
|
||
def main(): | ||
for i in range(10): | ||
mlflow.log_metric("metric0", i + 1, step=i) | ||
|
||
mlflow.set_tag("tag0", "value0") | ||
mlflow.log_param("param0", "value0") | ||
temp_dir = tempfile.mkdtemp() | ||
artifact_path = os.path.join(temp_dir, "artifact0.txt") | ||
with open(artifact_path, "wt") as f: | ||
f.write("hello\n") | ||
mlflow.log_artifact(artifact_path) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
alembic==1.12.0 | ||
blinker==1.6.2 | ||
certifi==2023.7.22 | ||
charset-normalizer==3.3.0 | ||
click==8.1.7 | ||
cloudpickle==2.2.1 | ||
contourpy==1.1.1 | ||
cycler==0.12.0 | ||
databricks-cli==0.17.8 | ||
docker==6.1.3 | ||
entrypoints==0.4 | ||
Flask==2.3.3 | ||
fonttools==4.43.0 | ||
gitdb==4.0.10 | ||
GitPython==3.1.37 | ||
gunicorn==21.2.0 | ||
idna==3.4 | ||
importlib-metadata==6.8.0 | ||
itsdangerous==2.1.2 | ||
Jinja2==3.1.2 | ||
joblib==1.3.2 | ||
kiwisolver==1.4.5 | ||
Mako==1.2.4 | ||
Markdown==3.4.4 | ||
MarkupSafe==2.1.3 | ||
matplotlib==3.8.0 | ||
mlflow==2.7.1 | ||
numpy==1.26.0 | ||
oauthlib==3.2.2 | ||
packaging==23.1 | ||
pandas==2.1.1 | ||
Pillow==10.0.1 | ||
pip==23.2.1 | ||
protobuf==4.24.3 | ||
pyarrow==13.0.0 | ||
PyJWT==2.8.0 | ||
pyparsing==3.1.1 | ||
python-dateutil==2.8.2 | ||
pytz==2023.3.post1 | ||
PyYAML==6.0.1 | ||
querystring-parser==1.2.4 | ||
requests==2.31.0 | ||
scikit-learn==1.3.1 | ||
scipy==1.11.3 | ||
setuptools==65.5.0 | ||
six==1.16.0 | ||
smmap==5.0.1 | ||
SQLAlchemy==2.0.21 | ||
sqlparse==0.4.4 | ||
tabulate==0.9.0 | ||
threadpoolctl==3.2.0 | ||
typing_extensions==4.8.0 | ||
tzdata==2023.3 | ||
urllib3==1.26.16 | ||
websocket-client==1.6.3 | ||
Werkzeug==3.0.0 | ||
zipp==3.17.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
#!/usr/bin/env bash | ||
|
||
set -o errexit | ||
set -o pipefail | ||
|
||
if [[ "$#" -ne 1 ]]; then | ||
echo "Usage: $0 <python package>" | ||
exit 1 | ||
fi | ||
|
||
cd "$(dirname $(dirname "$0"))" | ||
|
||
VENV_DIR="$(mktemp -d)" | ||
|
||
rm -rf "${VENV_DIR}" | ||
|
||
# Use the Python interpeter that Bazel uses | ||
bazel run @python_3_11//:python3 -- -m venv "${VENV_DIR}" | ||
|
||
# Create a venv with existing dependencies | ||
source "${VENV_DIR}/bin/activate" | ||
pip install -U pip | ||
pip install -r python_requirements_lock.txt | ||
|
||
# add the new dependency | ||
pip install "$@" | ||
|
||
# freeze the new requirements | ||
pip freeze --all > python_requirements_lock.txt | ||
# bazel run //:gazelle_python_manifest.update | ||
|
||
deactivate |
Oops, something went wrong.