diff --git a/Makefile b/Makefile index 84982cde..c317f55f 100644 --- a/Makefile +++ b/Makefile @@ -158,7 +158,7 @@ FLAGS += -DTERRA_LLVM_HEADERS_HAVE_NDEBUG endif LIBOBJS = tkind.o tcompiler.o tllvmutil.o tcwrapper.o tinline.o terra.o lparser.o lstring.o lobject.o lzio.o llex.o lctype.o treadnumber.o tcuda.o tdebug.o tinternalizedfiles.o lj_strscan.o -LIBLUA = terralib.lua strict.lua cudalib.lua asdl.lua +LIBLUA = terralib.lua strict.lua cudalib.lua asdl.lua terralist.lua EXEOBJS = main.o linenoise.o diff --git a/msvc/Makefile b/msvc/Makefile index 51bdc646..80fa00fa 100755 --- a/msvc/Makefile +++ b/msvc/Makefile @@ -110,12 +110,13 @@ $(LUAJIT) $(TERRA_DIR)\release\include\terra\lua.h: "$(LUAJIT_DIR)\src\luajit.c" copy lauxlib.h "$(TERRA_DIR)\release\include\terra" copy luaconf.h "$(TERRA_DIR)\release\include\terra" -"$(BUILD)\terralib.h" "$(BUILD)\strict.h" "$(BUILD)\cudalib.h": "$(SRC)\terralib.lua" "$(SRC)\strict.lua" "$(SRC)\cudalib.lua" $(LUAJIT) "$(TERRA_DIR)\release\include\terra\lua.h" +"$(BUILD)\terralib.h" "$(BUILD)\strict.h" "$(BUILD)\cudalib.h" "$(BUILD)\asdl.h" "$(BUILD)\terralist.h": "$(SRC)\terralib.lua" "$(SRC)\strict.lua" "$(SRC)\cudalib.lua" "$(SRC)\asdl.lua" "$(SRC)\terralist.lua" $(LUAJIT) "$(TERRA_DIR)\release\include\terra\lua.h" set LUA_PATH=$(LUAJIT_DIR)\src\?.lua $(LUAJIT) -bg "$(SRC)\terralib.lua" "$(BUILD)\terralib.h" $(LUAJIT) -bg "$(SRC)\strict.lua" "$(BUILD)\strict.h" $(LUAJIT) -bg "$(SRC)\cudalib.lua" "$(BUILD)\cudalib.h" $(LUAJIT) -bg "$(SRC)\asdl.lua" "$(BUILD)\asdl.h" + $(LUAJIT) -bg "$(SRC)\terralist.lua" "$(BUILD)\terralist.h" "$(BUILD)\clangpaths.h": $(LUAJIT) "$(SRC)\genclangpaths.lua" cd "$(TERRA_DIR)" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ad30e1e3..d71641f2 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -43,6 +43,7 @@ list(APPEND TERRA_LIB_LUA_SRC strict.lua cudalib.lua asdl.lua + terralist.lua ) foreach(LUA_SRC ${TERRA_LIB_LUA_SRC}) diff --git a/src/asdl.lua b/src/asdl.lua index 5c2c4b3f..f913f92f 100644 --- a/src/asdl.lua +++ b/src/asdl.lua @@ -1,40 +1,4 @@ -local List = {} -List.__index = List -for k,v in pairs(table) do - List[k] = v -end -setmetatable(List, { __call = function(self, lst) - if lst == nil then - lst = {} - end - return setmetatable(lst,self) -end}) -function List:map(fn,...) - local l = List() - if type(fn) == "function" then - for i,v in ipairs(self) do - l[i] = fn(v,...) - end - else - for i,v in ipairs(self) do - local sel = v[fn] - if type(sel) == "function" then - l[i] = sel(v,...) - else - l[i] = sel - end - end - end - return l -end -function List:insertall(elems) - for i,e in ipairs(elems) do - self:insert(e) - end -end -function List:isclassof(exp) - return getmetatable(exp) == self -end +local List = require("terralist") local Context = {} function Context:__index(idx) diff --git a/src/terra.cpp b/src/terra.cpp index b7443b29..ee0c6320 100644 --- a/src/terra.cpp +++ b/src/terra.cpp @@ -146,6 +146,7 @@ int terra_lualoadstring(lua_State * L) { #include "terralib.h" #include "strict.h" #include "asdl.h" +#include "terralist.h" int terra_loadandrunbytecodes(lua_State * L, const unsigned char * bytecodes, size_t size, const char * name) { return luaL_loadbuffer(L, (const char *)bytecodes, size, name) @@ -245,6 +246,7 @@ int terra_initwithoptions(lua_State * L, terra_Options * options) { return err; } err = terra_loadandrunbytecodes(T->L,(const unsigned char*)luaJIT_BC_strict,luaJIT_BC_strict_SIZE, "strict.lua") + || terra_loadandrunbytecodes(T->L,(const unsigned char*)luaJIT_BC_terralist,luaJIT_BC_terralist_SIZE, "terralist.lua") || terra_loadandrunbytecodes(T->L,(const unsigned char*)luaJIT_BC_asdl,luaJIT_BC_asdl_SIZE, "asdl.lua") #ifndef TERRA_EXTERNAL_TERRALIB || terra_loadandrunbytecodes(T->L,(const unsigned char*)luaJIT_BC_terralib,luaJIT_BC_terralib_SIZE, "terralib.lua"); diff --git a/src/terralist.lua b/src/terralist.lua new file mode 100644 index 00000000..77900c4b --- /dev/null +++ b/src/terralist.lua @@ -0,0 +1,405 @@ +--[[ +-- the List type is a plain Lua table with additional methods that come from: +-- 1. all the methods in Lua's 'table' global +-- 2. a list of higher-order functions based on sml's (fairly minimal) list type. + +-- For each function that is an argument of a high-order List function can be either: +-- 1. a real Lua function +-- 2. a string of an operator "+" (see op table) +-- 3. a string that specifies a field or method to call on the object +-- local mylist = List { a,b,c } +-- mylist:map("foo") -- selects the fields: a.foo, b.foo, c.foo, etc. +-- -- if a.foo is a function it will be treated as a method a:foo() +-- extra arguments to the higher-order function are passed through to these function. +-- rationale: Lua inline function syntax is verbose, this functionality avoids +-- inline functions in many cases + +list:sub(i,j) -- Lua's string.sub, but for lists +list:rev() : List[A] -- reverse list +list:app(fn : A -> B) : {} -- app fn to every element +list:map(fn : A -> B) : List[B] -- apply map to every element resulting in new list +list:filter(fn : A -> boolean) : List[A] -- new list with elements were fn(e) is true +list:flatmap(fn : A -> List[B]) : List[B] -- apply map to every element, resulting in lists which are all concatenated together +list:find(fn : A -> boolean) : A? -- find the first element in list satisfying condition +list:partition(fn : A -> {K,V}) : Map[ K,List[V] ] -- apply k,v = fn(e) to each element and group the values 'v' into bin of the same 'k' +list:fold(init : B,fn : {B,A} -> B) -> B -- recurrence fn(a[2],fn(a[1],init)) ... +list:reduce(fn : {B,A} -> B) -> B -- recurrence fn(a[3],fn(a[2],a[1])) +list:reduceor(init : B,fn : {B,A} -> B) -> B -- recurrence fn(a[3],fn(a[2],a[1])) or init if the list is empty +list:exists(fn : A -> boolean) : boolean -- is any fn(e) true in list +list:all(fn : A -> boolean) : boolean -- are all fn(e) true in list + +Every function that takes a higher-order function also has a 'i' variant that +Also provides the list index to the function: + +list:mapi(fn : {int,A} -> B) -> List[B] +]] + +local List = {} +List.__index = List +for k,v in pairs(table) do + List[k] = v +end +setmetatable(List, { __call = function(self, lst) + if lst == nil then + lst = {} + end + return setmetatable(lst,self) +end}) +function List:isclassof(exp) + return getmetatable(exp) == self +end +function List:insertall(elems) + for i,e in ipairs(elems) do + self:insert(e) + end +end +function List:rev() + local l,N = List(),#self + for i = 1,N do + l[i] = self[N-i+1] + end + return l +end +function List:sub(i,j) + local N = #self + if not j then + j = N + end + if i < 0 then + i = N+i+1 + end + if j < 0 then + j = N+j+1 + end + local l = List() + for c = i,j do + l:insert(self[c]) + end + return l +end +function List:__tostring() + return ("{%s}"):format(self:map(tostring):concat(",")) +end + +local OpTable = { +["+"] = function(x,y) return x + y end; +["*"] = function(x,y) return x * y end; +["/"] = function(x,y) return x / y end; +["%"] = function(x,y) return x % y end; +["^"] = function(x,y) return x ^ y end; +[".."] = function(x,y) return x .. y end; +["<"] = function(x,y) return x < y end; +[">"] = function(x,y) return x > y end; +["<="] = function(x,y) return x <= y end; +[">="] = function(x,y) return x >= y end; +["~="] = function(x,y) return x ~= y end; +["~="] = function(x,y) return x == y end; +["and"] = function(x,y) return x and y end; +["or"] = function(x,y) return x or y end; +["not"] = function(x) return not x end; +["-"] = function(x,y) + if not y then + return -x + else + return x - y + end +end +} + +local function selector(key) + local fn = OpTable[key] + if fn then return fn end + return function(v,...) + local sel = v[key] + if type(sel) == "function" then + return sel(v,...) + else + return sel + end + end +end +local function selectori(key) + local fn = OpTable[key] + if fn then return fn end + return function(i,v,...) + local sel = v[key] + if type(sel) == "function" then + return sel(i,v,...) + else + return sel + end + end +end + +function List:mapi(fn,...) + if type(fn) ~= "function" then + fn = selectori(fn) + end + local l = List() + for i,v in ipairs(self) do + l[i] = fn(i,v,...) + end + return l +end +function List:map(fn,...) + if type(fn) ~= "function" then + fn = selector(fn) + end + local l = List() + for i,v in ipairs(self) do + l[i] = fn(v,...) + end + return l +end + + +function List:appi(fn,...) + if type(fn) ~= "function" then + fn = selectori(fn) + end + for i,v in ipairs(self) do + fn(i,v,...) + end +end +function List:app(fn,...) + if type(fn) ~= "function" then + fn = selector(fn) + end + for i,v in ipairs(self) do + fn(v,...) + end +end + + +function List:filteri(fn,...) + if type(fn) ~= "function" then + fn = selectori(fn) + end + local l = List() + for i,v in ipairs(self) do + if fn(i,v,...) then + l:insert(v) + end + end + return l +end +function List:filter(fn,...) + if type(fn) ~= "function" then + fn = selector(fn) + end + local l = List() + for i,v in ipairs(self) do + if fn(v,...) then + l:insert(v) + end + end + return l +end + +function List:flatmapi(fn,...) + if type(fn) ~= "function" then + fn = selectori(fn) + end + local l = List() + for i,v in ipairs(self) do + local r = fn(i,v,...) + for j,v2 in ipairs(r) do + l:insert(v2) + end + end + return l +end +function List:flatmap(fn,...) + if type(fn) ~= "function" then + fn = selector(fn) + end + local l = List() + for i,v in ipairs(self) do + local r = fn(v,...) + for j,v2 in ipairs(r) do + l:insert(v2) + end + end + return l +end + +function List:findi(fn,...) + if type(fn) ~= "function" then + fn = selectori(fn) + end + local l = List() + for i,v in ipairs(self) do + if fn(i,v,...) then + return v + end + end + return nil +end +function List:find(fn,...) + if type(fn) ~= "function" then + fn = selector(fn) + end + local l = List() + for i,v in ipairs(self) do + if fn(v,...) then + return v + end + end + return nil +end + +function List:partitioni(fn,...) + if type(fn) ~= "function" then + fn = selectori(fn) + end + local m = {} + for i,v in ipairs(self) do + local k,v2 = fn(i,v,...) + local l = m[k] + if not l then + l = List() + m[k] = l + end + l:insert(v2) + end + return m +end +function List:partition(fn,...) + if type(fn) ~= "function" then + fn = selector(fn) + end + local m = {} + for i,v in ipairs(self) do + local k,v2 = fn(v,...) + local l = m[k] + if not l then + l = List() + m[k] = l + end + l:insert(v2) + end + return m +end + +function List:foldi(init,fn,...) + if type(fn) ~= "function" then + fn = selectori(fn) + end + local s = init + for i,v in ipairs(self) do + s = fn(i,s,v,...) + end + return s +end +function List:fold(init,fn,...) + if type(fn) ~= "function" then + fn = selector(fn) + end + local s = init + for i,v in ipairs(self) do + s = fn(s,v,...) + end + return s +end + +function List:reducei(fn,...) + if type(fn) ~= "function" then + fn = selectori(fn) + end + local N = #self + assert(N > 0, "reduce requires non-empty list") + local s = self[1] + for i = 2,N do + s = fn(i,s,self[i],...) + end + return s +end +function List:reduce(fn,...) + if type(fn) ~= "function" then + fn = selector(fn) + end + local N = #self + assert(N > 0, "reduce requires non-empty list") + local s = self[1] + for i = 2,N do + s = fn(s,self[i],...) + end + return s +end + +function List:reduceori(init,fn,...) + if type(fn) ~= "function" then + fn = selectori(fn) + end + local N = #self + if N == 0 then + return init + end + local s = self[1] + for i = 2,N do + s = fn(i,s,self[i],...) + end + return s +end +function List:reduceor(init,fn,...) + if type(fn) ~= "function" then + fn = selector(fn) + end + local N = #self + if N == 0 then + return init + end + local s = self[1] + for i = 2,N do + s = fn(s,self[i],...) + end + return s +end + +function List:existsi(fn,...) + if type(fn) ~= "function" then + fn = selectori(fn) + end + for i,v in ipairs(self) do + if fn(i,v,...) then + return true + end + end + return false +end +function List:exists(fn,...) + if type(fn) ~= "function" then + fn = selector(fn) + end + for i,v in ipairs(self) do + if fn(v,...) then + return true + end + end + return false +end + +function List:alli(fn,...) + if type(fn) ~= "function" then + fn = selectori(fn) + end + for i,v in ipairs(self) do + if not fn(i,v,...) then + return false + end + end + return true +end +function List:all(fn,...) + if type(fn) ~= "function" then + fn = selector(fn) + end + for i,v in ipairs(self) do + if not fn(v,...) then + return false + end + end + return true +end + +package.loaded["terralist"] = List diff --git a/tests/list.t b/tests/list.t new file mode 100644 index 00000000..86a8859e --- /dev/null +++ b/tests/list.t @@ -0,0 +1,87 @@ +local List = require("terralist") + + +local a = List { 2,3,4 } + + +local b = a:map(function(x) return x + 1 end) +for i,v in ipairs(b) do + assert(i+2 == v) +end +local c = b:mapi(function(i,x) return x - i end) +for i,v in ipairs(c) do + assert(v == 2) +end + +local s = 0 +a:app(function(x) s = s + x end) +assert(s == 9) +a:appi(function(i,x) s = s + i + x end) +assert(s == 24,s) + +local r = a:filteri(function(i,x) return (i+x) <= 5 end) +assert(#r == 2 and r[2] == 3) +local r = a:filter(function(x) return x % 2 == 0 end) +assert(#r == 2 and r[2] == 4) + +local r = a:flatmapi(function(i,x) return List{x,x+i+1} end) +assert(#r == 2*#a and r[4] == 6) + +local r = a:flatmapi(function(i,x) return List{x,x+1} end) +assert(#r == 2*#a and r[4] == 4) + +assert(a:findi(function(i,x) return i == 2 and x == 3 end) == 3) +assert(a:findi(function() return false end) == nil) + +assert(a:find(function(x) return x == 2 end) == 2) +assert(a:find(function(x) return false end) == nil) + +local r = a:partitioni(function(i,x) return i % 2, x end) +assert(#r[0] == 1 and r[0][1] == 3 and #r[1] == 2,#r[1]) +local r = a:partition(function(x) return x % 2, x end) +assert(#r[0] == 2 and r[0][1] == 2 and #r[1] == 1,#r[1]) +assert(a:rev()[1] == 4 and a:rev()[3] == 2) + +local v34 = a:sub(2) + +assert(v34[1] == 3 and v34[2] == 4) + +local v3 = a:sub(2,-2) +assert(v3[1] == 3 and #v3 == 1) + +local v34 = a:sub(-2) +assert(v34[1] == 3 and v34[2] == 4) +assert(15 == a:foldi(0,function(i,s,x) return i+s+x end)) +assert(9 == a:fold(0,"+")) +assert(24 == a:fold(1,"*")) + +assert(9 == a:reducei(function(i,s,x) return s + x end)) + +local Wrap +local Wrapper = {} +function Wrapper:plus(rhs) return Wrap(self.a + rhs.a) end +function Wrap(x) + setmetatable({ a = x }, {__index = Wrapper}) +end + +local A = { a = 3 } +local B = { a = 4 } +function A:plus(rhs,c) return self.a + rhs.a + c end +function B:plus(rhs,c) return self.a + rhs.a + c end + +local l = List { A, B} +assert(7 + 4 == l:reduce("plus",4)) + +assert(a:existsi(function(i,x) return i == 2 and x == 3 end)) +assert(a:exists(function(x) return x == 3 end)) +assert(not a:exists(function(x) return x == -1 end)) + +assert(a:alli(function(i,x) return type(x) == "number" end)) +assert(not a:alli(function(i,x) return x == 2 end)) +assert(a:all(tonumber)) +assert(not a:all("not")) +assert(not a:exists("not")) +assert(tostring(List{1,2}) == "{1,2}") +assert(List{}:reduceor(3,function()end) == 3) + +assert(List{1,5}:reduceori(3,function(i,x,y) return x+y end) == 6)