diff --git a/src/pyobsplot/js_modules.py b/src/pyobsplot/js_modules.py index 45210ce..19cd88a 100644 --- a/src/pyobsplot/js_modules.py +++ b/src/pyobsplot/js_modules.py @@ -1,4 +1,6 @@ +import io from functools import partial +from pathlib import Path from typing import Callable, Literal from pyobsplot.obsplot import Obsplot @@ -35,7 +37,16 @@ def plot( if provided, plot is saved to disk to an HTML file instead of displayed as a jupyter widget, by default None """ - format_value = format or _plot_format + format_value = format + if path is not None and not isinstance(path, io.StringIO): + extension = Path(path).suffix.lower()[1:] + allowed_extensions = ["html", "svg", "pdf", "png"] + if extension not in allowed_extensions: + msg = f"Output file extension should be one of {allowed_extensions}" + raise ValueError(msg) + format_value = format_value or extension + + format_value = format_value or _plot_format op = Obsplot(format=format_value) # type: ignore return op(spec, path=path)