Skip to content

Commit

Permalink
After creating venv, unconditionally find new binary. (#83)
Browse files Browse the repository at this point in the history
* After creating venv, unconditionally find new binary.

This isn't perfect, it can still find a non-venv binary, but it should
be closer to right.

* Update CLI rendering for venv events to key on venv path

* Refactor which into utils, and find_runtime tweaks

* Debug logging and test logic tweaks

* Drop asserts in favor of best effort chaining

* Drop debug logging

---------

Co-authored-by: Amethyst Reese <[email protected]>
  • Loading branch information
thatch and amyreese authored May 29, 2024
1 parent 83e88ac commit c985a75
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 79 deletions.
9 changes: 6 additions & 3 deletions thx/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, cast, Dict, List, Optional

from rich.console import Group
Expand All @@ -13,6 +14,7 @@

from .types import (
Context,
ContextEvent,
Event,
Fail,
Job,
Expand All @@ -30,7 +32,7 @@

@dataclass
class RichRenderer:
venvs: Dict[Context, Event] = field(default_factory=dict)
venvs: Dict[Path, ContextEvent] = field(default_factory=dict)
latest: Dict[Job, Dict[Context, Dict[Step, Event]]] = field(
default_factory=lambda: defaultdict(lambda: defaultdict(dict))
)
Expand Down Expand Up @@ -66,7 +68,7 @@ def __call__(self, event: Optional[Event]) -> None:
return

if isinstance(event, (VenvCreate, VenvError, VenvReady)):
venvs[event.context] = event
venvs[event.context.venv] = event
elif isinstance(event, JobEvent):
step = event.step
job = step.job
Expand All @@ -76,7 +78,8 @@ def __call__(self, event: Optional[Event]) -> None:

if venvs and not all(isinstance(v, VenvReady) for v in venvs.values()):
tree = Tree("Preparing virtualenvs...")
for context, event in venvs.items():
for _, event in venvs.items():
context = event.context
if isinstance(event, VenvReady):
text = Text(f"{context.python_version}> done", style="green")
elif isinstance(event, VenvError):
Expand Down
26 changes: 14 additions & 12 deletions thx/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@

from aioitertools.asyncio import as_generated

from .runner import check_command, which

from .runner import check_command
from .types import (
CommandError,
Config,
Expand All @@ -27,8 +26,7 @@
VenvReady,
Version,
)

from .utils import timed, version_match
from .utils import timed, venv_bin_path, version_match, which

LOG = logging.getLogger(__name__)
PYTHON_VERSION_RE = re.compile(r"Python (\d+\.\d+[a-zA-Z0-9-_.]+)\+?")
Expand Down Expand Up @@ -81,11 +79,12 @@ def find_runtime(
version: Version, venv: Optional[Path] = None
) -> Tuple[Optional[Path], Optional[Version]]:
if venv and venv.is_dir():
bin_dir = venv / "bin"
if bin_dir.is_dir():
binary_path_str = shutil.which("python", path=f"{bin_dir.as_posix()}")
if binary_path_str:
return Path(binary_path_str), None
bin_dir = venv_bin_path(venv)
binary_path_str = shutil.which("python", path=bin_dir.as_posix())
if binary_path_str:
binary_path = Path(binary_path_str)
binary_version = runtime_version(binary_path)
return binary_path, binary_version

# TODO: better way to find specific micro/pre/post versions?
binary_names = [
Expand Down Expand Up @@ -198,9 +197,6 @@ async def prepare_virtualenv(context: Context, config: Config) -> AsyncIterator[
import venv

venv.create(context.venv, prompt=prompt, with_pip=True)
new_python_path, _ = find_runtime(context.python_version, context.venv)
assert new_python_path is not None
context.python_path = new_python_path

else:
await check_command(
Expand All @@ -214,6 +210,12 @@ async def prepare_virtualenv(context: Context, config: Config) -> AsyncIterator[
]
)

new_python_path, new_python_version = find_runtime(
context.python_version, context.venv
)
context.python_path = new_python_path or context.python_path
context.python_version = new_python_version or context.python_version

# upgrade pip
yield VenvCreate(context, message="upgrading pip")
await check_command(
Expand Down
24 changes: 2 additions & 22 deletions thx/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@
import asyncio
import logging
import os
import platform
import shlex
import shutil
from asyncio.subprocess import PIPE
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence

from .types import (
Expand All @@ -22,28 +19,11 @@
Step,
StrPath,
)
from .utils import venv_bin_path, which

LOG = logging.getLogger(__name__)


def venv_bin_path(context: Context) -> Path:
if platform.system() == "Windows":
bin_path = context.venv / "Scripts"
else:
bin_path = context.venv / "bin"
return bin_path


def which(name: str, context: Context) -> str:
bin_path = venv_bin_path(context).as_posix()
binary = shutil.which(name, path=bin_path)
if binary is None:
binary = shutil.which(name)
if binary is None:
return name
return binary


def render_command(run: str, context: Context, config: Config) -> Sequence[str]:
run = run.format(**config.values, python_version=context.python_version)
cmd = shlex.split(run)
Expand All @@ -58,7 +38,7 @@ async def run_command(
new_env: Optional[Dict[str, str]] = None
if context:
new_env = os.environ.copy()
new_env["PATH"] = f"{venv_bin_path(context)}{os.pathsep}{new_env['PATH']}"
new_env["PATH"] = f"{venv_bin_path(context.venv)}{os.pathsep}{new_env['PATH']}"
proc = await asyncio.create_subprocess_exec(
*cmd, stdout=PIPE, stderr=PIPE, env=new_env
)
Expand Down
4 changes: 2 additions & 2 deletions thx/tests/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def test_render_venv(self, live_mock: MagicMock) -> None:
live_mock.return_value.reset_mock()
render(event)

self.assertIn(ctx, render.venvs)
self.assertEqual(event, render.venvs[ctx])
self.assertIn(ctx.venv, render.venvs)
self.assertEqual(event, render.venvs[ctx.venv])
live_mock.return_value.update.assert_called_once()

def test_render_job(self, live_mock: MagicMock) -> None:
Expand Down
9 changes: 5 additions & 4 deletions thx/tests/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
VenvReady,
Version,
)
from ..utils import venv_bin_path

TEST_VERSIONS = [
Version(v)
Expand Down Expand Up @@ -228,21 +229,21 @@ def test_find_runtime_venv(self, runtime_mock: Mock, which_mock: Mock) -> None:

with self.subTest(version):
venv = context.venv_path(config, version)
(venv / "bin").mkdir(parents=True, exist_ok=True)
bin_dir = venv_bin_path(venv)
bin_dir.mkdir(parents=True, exist_ok=True)

expected = venv / "bin" / "python"
expected = bin_dir / "python"
result, _ = context.find_runtime(version, venv)
self.assertEqual(expected, result)

which_mock.assert_has_calls(
[
call(
"python",
path=(venv / "bin").as_posix(),
path=bin_dir.as_posix(),
),
]
)
runtime_mock.assert_not_called()

@patch("thx.context.find_runtime")
def test_resolve_contexts_no_config(self, runtime_mock: Mock) -> None:
Expand Down
40 changes: 4 additions & 36 deletions thx/tests/runner.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,19 @@
# Copyright 2022 Amethyst Reese
# Licensed under the MIT License

import platform
import sys
from asyncio.subprocess import PIPE
from pathlib import Path
from unittest import skipIf, TestCase
from unittest.mock import ANY, call, Mock, patch
from unittest.mock import ANY, Mock, patch

from .. import runner
from ..types import CommandError, CommandResult, Config, Context, Job, Result, Version
from ..utils import venv_bin_path
from .helper import async_test


class RunnerTest(TestCase):
@patch("thx.runner.shutil.which")
def test_which(self, which_mock: Mock) -> None:
context = Context(Version("3.10"), Path(), Path("/fake/venv"))
fake_venv_bin = (
"/fake/venv/Scripts" if platform.system() == "Windows" else "/fake/venv/bin"
)
with self.subTest("found"):
which_mock.side_effect = lambda b, path: f"/usr/bin/{b}"
self.assertEqual("/usr/bin/frobfrob", runner.which("frobfrob", context))
which_mock.assert_has_calls([call("frobfrob", path=fake_venv_bin)])

with self.subTest("not in venv"):
which_mock.side_effect = [None, "/usr/bin/scoop"]
self.assertEqual("/usr/bin/scoop", runner.which("scoop", context))
which_mock.assert_has_calls(
[
call("scoop", path=fake_venv_bin),
call("scoop"),
]
)

with self.subTest("not found"):
which_mock.side_effect = None
which_mock.return_value = None
self.assertEqual("frobfrob", runner.which("frobfrob", context))
which_mock.assert_has_calls(
[
call("frobfrob", path=fake_venv_bin),
call("frobfrob"),
]
)

@patch("thx.runner.which")
def test_render_command(self, which_mock: Mock) -> None:
which_mock.return_value = "/opt/bin/frobfrob"
Expand All @@ -54,7 +22,7 @@ def test_render_command(self, which_mock: Mock) -> None:
result = runner.render_command("frobfrob check {module}.tests", context, config)
self.assertEqual(("/opt/bin/frobfrob", "check", "alpha.tests"), result)

@patch("thx.runner.shutil.which", return_value=None)
@patch("thx.utils.shutil.which", return_value=None)
def test_prepare_job(self, which_mock: Mock) -> None:
config = Config(values={"module": "beta"})
context = Context(Version("3.9"), Path(), Path())
Expand Down Expand Up @@ -99,7 +67,7 @@ async def test_run_command(self) -> None:
"/fake/binary", "something", stdout=PIPE, stderr=PIPE, env=ANY
)
self.assertIn(
str(runner.venv_bin_path(ctx)),
str(venv_bin_path(ctx.venv)),
exec_mock.call_args.kwargs["env"]["PATH"],
)

Expand Down
34 changes: 34 additions & 0 deletions thx/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
# Licensed under the MIT License

import asyncio
import platform
import time
from pathlib import Path
from typing import Any, List, Tuple
from unittest import TestCase
from unittest.mock import call, Mock, patch

from .. import utils
from ..types import Context, Job, Step, Version
Expand Down Expand Up @@ -78,6 +80,38 @@ async def foo(value: int, *args: Any, **kwargs: Any) -> int:
self.assertEqual("test message", timing.message)
self.assertEqual(job, timing.job)

@patch("thx.utils.shutil.which")
def test_which(self, which_mock: Mock) -> None:
context = Context(Version("3.10"), Path(), Path("/fake/venv"))
fake_venv_bin = (
"/fake/venv/Scripts" if platform.system() == "Windows" else "/fake/venv/bin"
)
with self.subTest("found"):
which_mock.side_effect = lambda b, path: f"/usr/bin/{b}"
self.assertEqual("/usr/bin/frobfrob", utils.which("frobfrob", context))
which_mock.assert_has_calls([call("frobfrob", path=fake_venv_bin)])

with self.subTest("not in venv"):
which_mock.side_effect = [None, "/usr/bin/scoop"]
self.assertEqual("/usr/bin/scoop", utils.which("scoop", context))
which_mock.assert_has_calls(
[
call("scoop", path=fake_venv_bin),
call("scoop"),
]
)

with self.subTest("not found"):
which_mock.side_effect = None
which_mock.return_value = None
self.assertEqual("frobfrob", utils.which("frobfrob", context))
which_mock.assert_has_calls(
[
call("frobfrob", path=fake_venv_bin),
call("frobfrob"),
]
)

def test_version_match(self) -> None:
test_data: Tuple[Tuple[str, List[Version]], ...] = (
("3.8", [Version("3.8"), Version("3.8.10")]),
Expand Down
21 changes: 21 additions & 0 deletions thx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
# Licensed under the MIT License

import logging
import platform
import shutil
from asyncio import iscoroutinefunction
from dataclasses import dataclass, field, replace
from functools import wraps
from itertools import zip_longest
from pathlib import Path
from time import monotonic_ns
from typing import Any, Callable, List, Optional, TypeVar

Expand Down Expand Up @@ -109,6 +112,24 @@ def get_timings() -> List[timed]:
return result


def venv_bin_path(venv: Path) -> Path:
if platform.system() == "Windows":
bin_path = venv / "Scripts"
else:
bin_path = venv / "bin"
return bin_path


def which(name: str, context: Context) -> str:
bin_path = venv_bin_path(context.venv).as_posix()
binary = shutil.which(name, path=bin_path)
if binary is None:
binary = shutil.which(name)
if binary is None:
return name
return binary


def version_match(versions: List[Version], target: Version) -> List[Version]:
matches: List[Version] = []
for version in versions:
Expand Down

0 comments on commit c985a75

Please sign in to comment.