Skip to content

Commit

Permalink
added cursor class
Browse files Browse the repository at this point in the history
  • Loading branch information
jreadey committed Jan 31, 2019
1 parent 3dab0c4 commit fc091ca
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 6 deletions.
2 changes: 1 addition & 1 deletion h5pyd/_hl/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ def __iter__(self):
if shape[0] - i < numrows:
numrows = shape[0] - i
self.log.debug("get {} iter items".format(numrows))
arr = self[i:numrows]
arr = self[i:numrows+i]

yield arr[i%BUFFER_SIZE]

Expand Down
72 changes: 68 additions & 4 deletions h5pyd/_hl/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from __future__ import absolute_import
import numpy
import six
from six.moves import xrange
from .base import _decode
from .dataset import Dataset
from .objectid import DatasetID
Expand All @@ -21,7 +22,54 @@
from .h5type import check_dtype


class Cursor():
"""
Cursor for retreiving rows from a table
"""
def __init__(self, table, query=None, start=None, stop=None):
self._table = table
self._query = query
if start is None:
self._start = 0
else:
self._start = start
if stop is None:
self._stop = table.nrows
else:
self._stop = stop

def __iter__(self):
""" Iterate over the first axis. TypeError if scalar.
BEWARE: Modifications to the yielded data are *NOT* written to file.
"""
nrows = self._table.nrows
# to reduce round trips, grab BUFFER_SIZE items at a time
# TBD: set buffersize based on size of each row
BUFFER_SIZE = 1000

arr = None
query_complete = False

for indx in xrange(self._start, self._stop):
if indx%BUFFER_SIZE == 0:
# grab another buffer
read_count = BUFFER_SIZE
if nrows - indx < read_count:
read_count = nrows - indx
if self._query is None:

arr = self._table[indx:read_count+indx]
else:
# call table to return query result
if query_complete:
arr = None # nothing more to fetch
else:
arr = self._table.read_where(self._query, start=indx, limit=read_count)
if arr is not None and arr.shape[0] < read_count:
query_complete = True # we've gotten all the rows
if arr is not None and indx%BUFFER_SIZE < arr.shape[0]:
yield arr[indx%BUFFER_SIZE]

class Table(Dataset):

Expand Down Expand Up @@ -76,7 +124,7 @@ def read(self, start=None, stop=None, step=None, field=None, out=None):



def read_where(self, condition, condvars=None, field=None, start=None, stop=None, step=None):
def read_where(self, condition, condvars=None, field=None, start=None, stop=None, step=None, limit=None):
"""Read rows from table using pytable-style condition
"""
names = () # todo
Expand Down Expand Up @@ -148,10 +196,19 @@ def readtime_dtype(basetype, names):
try:
self.log.debug("params: {}".format(params))
rsp = self.GET(req, params=params)
count = len(rsp["value"])
values = rsp["value"]
count = len(values)
self.log.info("got {} rows".format(count))
if count > 0:
data.extend(rsp['value'])
if limit is None or count + len(data) <= limit:
# add in all the data
data.extend(values)
else:
# we've hit the limit for number of rows to return
add_count = limit - len(data)
self.log.debug("adding {} from {} to rrows".format(add_count, count))
data.extend(values[:add_count])

# advance to next page
cursor += page_size
except IOError as ioe:
Expand All @@ -165,7 +222,7 @@ def readtime_dtype(basetype, names):
# otherwise, just raise the exception
self.log.info("Unexpected exception: {}".format(ioe.errno))
raise ioe
if cursor >= stop:
if cursor >= stop or limit and len(data) == limit:
self.log.info("completed iteration, returning: {} rows".format(len(data)))
break

Expand All @@ -190,6 +247,13 @@ def readtime_dtype(basetype, names):
arr = numpy.asscalar(arr)

return arr

def create_cursor(self, condition=None, start=None, stop=None):
"""Return a cursor for iteration
"""
return Cursor(self, query=condition, start=start, stop=stop)



def append(self, rows):
""" Append rows to end of table
Expand Down
29 changes: 28 additions & 1 deletion test/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,21 @@ def test_create_table(self):

self.assertEqual(table.colnames, ['real', 'img'])
self.assertEqual(table.nrows, count)

num_rows = 0
for row in table:
self.assertEqual(len(row), 2)
num_rows += 1
self.assertEqual(num_rows, count)

# try the same thing using cursor object
cursor = table.create_cursor()
num_rows = 0
for row in cursor:
self.assertEqual(len(row), 2)
num_rows += 1
self.assertEqual(num_rows, count)

arr = table.read(start=5, stop=6)
self.assertEqual(arr.shape, (1,))

Expand Down Expand Up @@ -96,11 +109,25 @@ def test_query_table(self):
# first two columns will come back as bytes, not strs
self.assertEqual(row[col], item[col])

quotes = table.read_where("symbol == b'AAPL'")
condition = "symbol == b'AAPL'"
quotes = table.read_where(condition)
self.assertEqual(len(quotes), 4)
for i in range(4):
quote = quotes[i]
self.assertEqual(quote[0], b'AAPL')

# read up to 2 rows
quotes = table.read_where(condition, limit=2)
self.assertEqual(len(quotes), 2)

# use a query cursor
cursor = table.create_cursor(condition=condition)
num_rows = 0
for row in cursor:
self.assertEqual(len(row), 4)
num_rows += 1
self.assertEqual(num_rows, 4)

f.close()


Expand Down

0 comments on commit fc091ca

Please sign in to comment.