Skip to content

Commit

Permalink
Implement pairwise element dragging
Browse files Browse the repository at this point in the history
This adds two new APIs - `drag_pair_forwards` and its companion
`drag_pair_backwards` which allow dragging 'pairs' of elements within a
form.

An easy example is the key-value pairs within a map which typically are
meant to stay together when reordering.

This can, however, be used for any form. This means if you have some
vector containing logical pairs (because of the actual semantics of your
code) you can use a dedicated keybinding to drag them around together.

This change also introduces a new config option at:

`dragging.auto_drag_pairs = true|false`

which will alter the behaviour of the existing `drag_element_forwards`
and `drag_element_backwards` APIs in order to try infer whether they are
contained within a node that is made up of pairs.

For example if this setting is `true` and a `drag_element_forwards` is
used on the keys of a map then they will be dragged pairwise.

This new config option defaults to `true` under the assumption that this
is generally the desired behaviour.
  • Loading branch information
julienvincent committed Oct 11, 2024
1 parent de1c08f commit ea35080
Show file tree
Hide file tree
Showing 9 changed files with 269 additions and 10 deletions.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ paredit.setup({
-- this case it will place the cursor on the moved edge
cursor_behaviour = "auto", -- remain, follow, auto

dragging = {
-- If set to `true` paredit will attempt to infer if an element being
-- dragged is part of a 'paired' form like as a map. If so then the element
-- will be dragged along with it's pair.
auto_drag_pairs = true,
},

indent = {
-- This controls how nvim-paredit handles indentation when performing operations which
-- should change the indentation of the form (such as when slurping or barfing).
Expand All @@ -88,6 +95,9 @@ paredit.setup({
[">e"] = { paredit.api.drag_element_forwards, "Drag element right" },
["<e"] = { paredit.api.drag_element_backwards, "Drag element left" },

[">p"] = { api.drag_pair_forwards, "Drag element pairs right" },
["<p"] = { api.drag_pair_backwards, "Drag element pairs left" },

[">f"] = { paredit.api.drag_form_forwards, "Drag form right" },
["<f"] = { paredit.api.drag_form_backwards, "Drag form left" },

Expand Down Expand Up @@ -287,6 +297,9 @@ require("nvim-paredit").setup({
-- Accepts a Treesitter node and should return true or false depending on whether the given node
-- can be considered a 'comment'
node_is_comment = function(node) end,
-- Accepts a Treesitter node and should return a boolean indicating whether or not the given node contains
-- 'paired' elements. For example a clojure map (`{:a 1 :b 2}`)
node_contains_pairs = function(node) end,
-- Accepts a Treesitter node representing a form and should return the 'edges' of the node. This
-- includes the node text and the range covered by the node
get_form_edges = function(node)
Expand Down Expand Up @@ -332,6 +345,8 @@ paredit.api.slurp_forwards()
- **`barf_backwards`**
- **`drag_element_forwards`**
- **`drag_element_backwards`**
- **`drag_pair_forwards`**
- **`drag_pair_backwards`**
- **`drag_form_forwards`**
- **`drag_form_backwards`**
- **`raise_element`**
Expand Down
95 changes: 89 additions & 6 deletions lua/nvim-paredit/api/dragging.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
local traversal = require("nvim-paredit.utils.traversal")
local common = require("nvim-paredit.utils.common")
local ts = require("nvim-treesitter.ts_utils")
local config = require("nvim-paredit.config")
local langs = require("nvim-paredit.lang")

local M = {}
Expand Down Expand Up @@ -44,24 +46,65 @@ function M.drag_form_backwards()
ts.swap_nodes(root, sibling, buf, true)
end

function M.drag_element_forwards()
local function find_current_pair(pairs, current_node)
for i, pair in ipairs(pairs) do
for _, node in ipairs(pair) do
if node:equal(current_node) then
return i, pair
end
end
end
end

local function drag_pairs(opts)
local lang = langs.get_language_api()
local current_node = lang.get_node_root(ts.get_node_at_cursor())
if not current_node then
return
end

local sibling = current_node:next_named_sibling()
if not sibling then
local direction = 1
if opts.reversed then
direction = -1
end

local parent = current_node:parent()
if not parent then
return
end

local children = traversal.get_children_ignoring_comments(parent, {
lang = lang,
})
local pairs = common.chunk_table(children, 2)
local chunk_index, pair = find_current_pair(pairs, current_node)

local corresponding_pair = pairs[chunk_index + direction]
if not corresponding_pair then
return
end

local buf = vim.api.nvim_get_current_buf()
ts.swap_nodes(current_node, sibling, buf, true)
ts.swap_nodes(pair[2], corresponding_pair[2], buf, true)
ts.swap_nodes(pair[1], corresponding_pair[1], buf, true)
end

function M.drag_element_backwards()
local function drag_element(opts)
local lang = langs.get_language_api()
local current_node = lang.get_node_root(ts.get_node_at_cursor())

local sibling = current_node:prev_named_sibling()
local parent = current_node:parent()
if opts.dragging.auto_drag_pairs and lang.node_contains_pairs and lang.node_contains_pairs(parent) then
return drag_pairs(opts)
end

local sibling
if opts.reversed then
sibling = current_node:prev_named_sibling()
else
sibling = current_node:next_named_sibling()
end

if not sibling then
return
end
Expand All @@ -70,4 +113,44 @@ function M.drag_element_backwards()
ts.swap_nodes(current_node, sibling, buf, true)
end

function M.drag_element_forwards(opts)
local drag_opts = vim.tbl_deep_extend(
"force",
{
dragging = config.config.dragging or {},
},
opts or {},
{
reversed = false,
}
)
drag_element(drag_opts)
end

function M.drag_element_backwards(opts)
local drag_opts = vim.tbl_deep_extend(
"force",
{
dragging = config.config.dragging or {},
},
opts or {},
{
reversed = true,
}
)
drag_element(drag_opts)
end

function M.drag_pair_forwards()
drag_pairs({
reversed = false,
})
end

function M.drag_pair_backwards()
drag_pairs({
reversed = true,
})
end

return M
4 changes: 4 additions & 0 deletions lua/nvim-paredit/api/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ local M = {

drag_element_forwards = dragging.drag_element_forwards,
drag_element_backwards = dragging.drag_element_backwards,

drag_pair_forwards = dragging.drag_pair_forwards,
drag_pair_backwards = dragging.drag_pair_backwards,

drag_form_forwards = dragging.drag_form_forwards,
drag_form_backwards = dragging.drag_form_backwards,

Expand Down
11 changes: 10 additions & 1 deletion lua/nvim-paredit/defaults.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ local unwrap = require("nvim-paredit.api.unwrap")
local M = {}

M.default_keys = {
["<localleader>@"] = { unwrap.unwrap_form_under_cursor, "Splice sexp", },
["<localleader>@"] = { unwrap.unwrap_form_under_cursor, "Splice sexp" },

[">)"] = { api.slurp_forwards, "Slurp forwards" },
[">("] = { api.barf_backwards, "Barf backwards" },
Expand All @@ -15,6 +15,9 @@ M.default_keys = {
[">e"] = { api.drag_element_forwards, "Drag element right" },
["<e"] = { api.drag_element_backwards, "Drag element left" },

[">p"] = { api.drag_pair_forwards, "Drag element pairs right" },
["<p"] = { api.drag_pair_backwards, "Drag element pairs left" },

[">f"] = { api.drag_form_forwards, "Drag form right" },
["<f"] = { api.drag_form_backwards, "Drag form left" },

Expand Down Expand Up @@ -107,6 +110,12 @@ M.default_keys = {
M.defaults = {
use_default_keys = true,
cursor_behaviour = "auto", -- remain, follow, auto
dragging = {
-- If set to `true` paredit will attempt to infer if an element being
-- dragged is part of a 'paired' form like as a map. If so then the element
-- will be dragged along with it's pair.
auto_drag_pairs = true,
},
indent = {
enabled = false,
indentor = require("nvim-paredit.indentation.native").indentor,
Expand Down
40 changes: 40 additions & 0 deletions lua/nvim-paredit/lang/clojure.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@ local form_types = {
"anon_fn_lit",
}

local base_pair_types = {
map_lit = 0,
-- case = 2,
-- cond = 1,
-- condp = 3,
}

local fn_pair_names = {
vec_lit = { "let", "loop", "binding", "with-open", "with-redefs" },
}

M.whitespace_chars = { " ", "," }

local function find_next_parent_form(current_node)
Expand Down Expand Up @@ -59,6 +70,31 @@ function M.node_is_comment(node)
return node:type() == "comment"
end

function M.node_contains_pairs(node)
local node_type = node:type()
for type, offset in pairs(base_pair_types) do
if type == node_type then
return true, offset
end
end

local fn = traversal.get_prev_sibling_ignoring_comments(node, {
lang = M,
})
if not fn then
return false
end

local fn_name = vim.treesitter.get_node_text(fn, 0)

local fn_names = fn_pair_names[node:type()]
if fn_names and common.included_in_table(fn_names, fn_name) then
return true
end

return false
end

function M.get_form_edges(node)
local outer_range = { node:range() }

Expand All @@ -67,19 +103,23 @@ function M.get_form_edges(node)
local left_bracket_range = { form:field("open")[1]:range() }
local right_bracket_range = { form:field("close")[1]:range() }

-- stylua: ignore
local left_range = {
outer_range[1], outer_range[2],
left_bracket_range[3], left_bracket_range[4]
}
-- stylua: ignore
local right_range = {
right_bracket_range[1], right_bracket_range[2],
outer_range[3], outer_range[4],
}

-- stylua: ignore
local left_text = vim.api.nvim_buf_get_text(0,
left_range[1], left_range[2],
left_range[3], left_range[4],
{})
-- stylua: ignore
local right_text = vim.api.nvim_buf_get_text(0,
right_range[1], right_range[2],
right_range[3], right_range[4],
Expand Down
5 changes: 2 additions & 3 deletions lua/nvim-paredit/lang/init.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
local common = require("nvim-paredit.utils.common")

local langs = {
clojure = require("nvim-paredit.lang.clojure"),
}
Expand All @@ -14,13 +12,14 @@ local function keys(tbl)
return result
end

--- @return table<string, function>
function M.get_language_api()
for l in string.gmatch(vim.bo.filetype, "[^.]+") do
if langs[l] ~= nil then
return langs[l]
end
end
return nil
error("Could not find language extension for filetype " .. vim.bo.filetype, vim.log.levels.ERROR)
end

function M.add_language_extension(filetype, api)
Expand Down
14 changes: 14 additions & 0 deletions lua/nvim-paredit/utils/common.lua
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@ function M.included_in_table(table, item)
return false
end

function M.chunk_table(tbl, chunk_size)
local result = {}
for i = 1, #tbl, chunk_size do
local chunk = {}
for j = 0, chunk_size - 1 do
if tbl[i + j] then
table.insert(chunk, tbl[i + j])
end
end
table.insert(result, chunk)
end
return result
end

-- Compares the two given { col, row } position tuples and returns -1/0/1 depending
-- on whether `a` is less than, equal to or greater than `b`
--
Expand Down
16 changes: 16 additions & 0 deletions lua/nvim-paredit/utils/traversal.lua
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,22 @@ function M.find_nearest_form(current_node, opts)
end
end

function M.get_children_ignoring_comments(node, opts)
local children = {}

local index = 0
local child = node:named_child(index)
while child do
if not child:extra() and not opts.lang.node_is_comment(child) then
table.insert(children, child)
end
index = index + 1
child = node:named_child(index)
end

return children
end

local function get_child_ignoring_comments(node, index, opts)
if index < 0 or index >= node:named_child_count() then
return
Expand Down
Loading

0 comments on commit ea35080

Please sign in to comment.