Commit 63660ee3 authored by Aurélien Campéas's avatar Aurélien Campéas
Browse files

tsio: provide a faster and more space efficient alternative implementation of the TimeSerie object

parent 831818129b8d
......@@ -12,8 +12,6 @@ from tshistory import schema, tsio
DATADIR = Path(__file__).parent / 'test' / 'data'
DBURI = 'postgresql://localhost:5433/postgres'
tshclass = tsio.TimeSerie
@pytest.fixture(scope='session')
def engine(request):
......@@ -28,8 +26,15 @@ def engine(request):
e = create_engine(uri)
yield e
@pytest.fixture(params=[tsio.TimeSerie,
tsio.BigdataTimeSerie],
scope='session')
def tsh(request, engine):
tsh = request.param()
yield tsh
# build a ts using the logs from another
tsh = tshclass()
log = tsh.log(engine, diff=True)
allnames = set()
for rev in log:
......@@ -46,7 +51,12 @@ def engine(request):
for name in allnames:
assert (tsh.get(engine, name) == tsh.get(engine, 'new_' + name)).all()
@pytest.fixture()
def tsh():
return tsio.TimeSerie()
schema.reset(engine)
meta = schema.MetaData()
# fix the schema module
r, c, cs = schema.make_schema(meta)
schema.registry = r
schema.changeset = c
schema.changeset_series = cs
schema.meta = meta
schema.init(engine)
......@@ -8,6 +8,7 @@ import numpy as np
import pytest
from mock import patch
from tshistory.tsio import BigdataTimeSerie
DATADIR = Path(__file__).parent / 'data'
......@@ -19,8 +20,16 @@ def assert_group_equals(g1, g2):
assert s1.equals(s2)
def remove_metadata(tsrepr):
if 'Freq' in tsrepr or 'Name' in tsrepr:
return tsrepr[:tsrepr.rindex('\n')]
return tsrepr
def assert_df(expected, df):
assert expected.strip() == df.to_string().strip()
exp = remove_metadata(expected.strip())
got = remove_metadata(df.to_string().strip())
assert exp == got
def genserie(start, freq, repeat, initval=None, tz=None, name=None):
......@@ -108,6 +117,8 @@ def test_changeset(engine, tsh):
def test_tstamp_roundtrip(engine, tsh):
if isinstance(tsh, BigdataTimeSerie):
return
ts = genserie(datetime(2017, 10, 28, 23),
'H', 4, tz='UTC')
ts.index = ts.index.tz_convert('Europe/Paris')
......@@ -440,7 +451,22 @@ def test_snapshots(engine, tsh):
for attr in ('diff', 'snapshot'):
df[attr] = df[attr].apply(lambda x: 0 if x is None else len(x))
assert_df("""
if isinstance(tsh, BigdataTimeSerie):
assert_df("""
id diff snapshot
0 1 0 35
1 2 36 0
2 3 36 0
3 4 36 47
4 5 36 0
5 6 36 0
6 7 36 0
7 8 36 59
8 9 36 0
9 10 36 67
""", df)
else:
assert_df("""
id diff snapshot
0 1 0 32
1 2 32 0
......
from datetime import datetime
from contextlib import contextmanager
import logging
import pickle
import zlib
import pandas as pd
import numpy as np
from sqlalchemy import Table, Column, Integer, ForeignKey
from sqlalchemy.sql.expression import select, func, desc
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.dialects.postgresql import JSONB, BYTEA
from tshistory import schema
......@@ -125,7 +127,7 @@ class TimeSerie(object):
csid = self._csid or self._newchangeset(cn, author)
value = {
'csid': csid,
'snapshot': tojson(newts),
'snapshot': self._serialize(newts),
}
# callback for extenders
self._complete_insertion_value(value, extra_scalars)
......@@ -147,8 +149,8 @@ class TimeSerie(object):
csid = self._csid or self._newchangeset(cn, author)
value = {
'csid': csid,
'diff': tojson(diff),
'snapshot': tojson(newsnapshot),
'diff': self._serialize(diff),
'snapshot': self._serialize(newsnapshot),
'parent': tip_id,
}
# callback for extenders
......@@ -285,6 +287,14 @@ class TimeSerie(object):
# /API
# Helpers
# ts serialisation
def _serialize(self, ts):
return tojson(ts)
def _deserialize(self, ts, name):
return fromjson(ts, name)
# serie table handling
def _ts_table_name(self, seriename):
......@@ -416,7 +426,7 @@ class TimeSerie(object):
snapid, snapdata = cn.execute(sql).fetchone()
except TypeError:
return None, None
return snapid, fromjson(snapdata, table.name)
return snapid, self._deserialize(snapdata, table.name)
def _build_snapshot_upto(self, cn, table, qfilter=()):
snapid, snapshot = self._find_snapshot(cn, table, qfilter)
......@@ -444,7 +454,7 @@ class TimeSerie(object):
# initial ts
ts = snapshot
for _, row in alldiffs.iterrows():
diff = fromjson(row['diff'], table.name)
diff = self._deserialize(row['diff'], table.name)
ts = self._apply_diff(ts, diff)
assert ts.index.dtype.name == 'datetime64[ns]' or len(ts) == 0
return ts
......@@ -468,7 +478,7 @@ class TimeSerie(object):
sql = select([table.c.diff])
sql = filtercset(sql)
return fromjson(cn.execute(sql).scalar(), name)
return self._deserialize(cn.execute(sql).scalar(), name)
def _compute_diff(self, fromts, tots):
"""Compute the difference between fromts and tots
......@@ -514,3 +524,32 @@ class TimeSerie(object):
result_ts.sort_index(inplace=True)
result_ts.name = base_ts.name
return result_ts
class BigdataTimeSerie(TimeSerie):
def _table_definition_for(self, seriename):
return Table(
seriename, schema.meta,
Column('id', Integer, primary_key=True),
Column('csid', Integer, ForeignKey('changeset.id'),
index=True, nullable=False),
# constraint: there is either .diff or .snapshot
Column('diff', BYTEA),
Column('snapshot', BYTEA),
Column('parent',
Integer,
ForeignKey('timeserie.%s.id' % seriename,
ondelete='cascade'),
nullable=True,
unique=True,
index=True),
schema='timeserie',
extend_existing=True
)
def _serialize(self, ts):
return zlib.compress(tojson(ts).encode('utf-8'))
def _deserialize(self, ts, name):
return fromjson(zlib.decompress(ts).decode('utf-8'), name)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment