Commit 21948dd5 authored by Aurélien Campéas's avatar Aurélien Campéas
Browse files

use `sqlhelp` to build the sql

parent 338cff91a528
......@@ -14,6 +14,7 @@ setup(name='tshistory',
'deprecated',
'dateutils',
'sqlalchemy',
'sqlhelp',
'click',
'mock',
'pytest_sa_pg',
......
import logging
from pathlib import Path
from tshistory.util import unilist, sqlfile
from sqlhelp import sqlfile
CREATEFILE = Path(__file__).parent / 'schema.sql'
......
......@@ -6,12 +6,10 @@ from pathlib import Path
import pandas as pd
import numpy as np
from tshistory.util import (
SeriesServices,
sqlfile,
sqlp,
sqlq
)
from sqlhelp import sqlfile, select
from tshistory.util import SeriesServices
SCHEMA = Path(__file__).parent / 'snapshot.sql'
......@@ -287,9 +285,9 @@ class Snapshot(SeriesServices):
def cset_heads_query(self, csetfilter=(), order='desc'):
tablename = self.tsh._serie_to_tablename(self.cn, self.seriename)
q = sqlq(
q = select(
'ts.cset', 'ts.snapshot'
).relation(
).table(
f'"{self.tsh.namespace}.timeserie"."{tablename}" as ts',
).join(
f'"{self.tsh.namespace}".changeset as cset on cset.id = ts.cset'
......@@ -298,16 +296,16 @@ class Snapshot(SeriesServices):
if csetfilter:
q.where('ts.cset <= cset.id')
for filtercb in csetfilter:
q.where(filtercb)
filtercb(q)
q.option(f'order by ts.id {order}')
q.order('ts.id', order)
return q
def find(self, csetfilter=(),
from_value_date=None, to_value_date=None):
q = self.cset_heads_query(csetfilter)
q.option('limit 1')
q.limit(1)
try:
csid, cid = q.do(self.cn).fetchone()
......@@ -342,8 +340,8 @@ class Snapshot(SeriesServices):
q = self.cset_heads_query(
(
sqlp('cset.id >= %(mincset)s', mincset=min(csets)),
sqlp('cset.id <= %(maxcset)s', maxcset=max(csets))
lambda q: q.where('cset.id >= %(mincset)s', mincset=min(csets)),
lambda q: q.where('cset.id <= %(maxcset)s', maxcset=max(csets))
),
order='asc'
)
......
......@@ -9,15 +9,13 @@ from pathlib import Path
import pandas as pd
from deprecated import deprecated
from sqlhelp import sqlfile, select
from tshistory.util import (
closed_overlaps,
num2float,
SeriesServices,
start_end,
sqlfile,
sqlp,
sqlq,
tx,
tzaware_serie
)
......@@ -116,7 +114,7 @@ class timeseries(SeriesServices):
csetfilter = []
if revision_date:
csetfilter.append(
sqlp(
lambda q: q.where(
f'cset.insertion_date <= %(idate)s', idate=revision_date
)
)
......@@ -165,9 +163,9 @@ class timeseries(SeriesServices):
def changeset_metadata(self, cn, csid):
assert isinstance(csid, int)
q = sqlq(
q = select(
'metadata'
).relation(
).table(
f'"{self.namespace}".changeset'
).where(
f'id = %(csid)s', csid=csid
......@@ -191,9 +189,9 @@ class timeseries(SeriesServices):
if tablename is None:
return
q = sqlq(
q = select(
'cset.id', 'cset.insertion_date'
).relation(
).table(
f'"{self.namespace}.timeserie"."{tablename}" as ts'
).join(
f'"{self.namespace}".changeset as cset on cset.id = ts.cset'
......@@ -217,7 +215,7 @@ class timeseries(SeriesServices):
todate=to_value_date
)
q.option('order by cset.id')
q.order('cset.id')
revs = q.do(cn).fetchall()
if not revs:
return {}
......@@ -247,7 +245,7 @@ class timeseries(SeriesServices):
idate,
snapshot.find(
csetfilter=[
sqlp('cset.id = %(csid)s', csid=csid)
lambda q: q.where('cset.id = %(csid)s', csid=csid)
],
from_value_date=from_date,
to_value_date=to_date)[1]
......@@ -311,9 +309,9 @@ class timeseries(SeriesServices):
def latest_insertion_date(self, cn, seriename):
tablename = self._serie_to_tablename(cn, seriename)
q = sqlq(
q = select(
'max(insertion_date)'
).relation(
).table(
f'"{self.namespace}".changeset as cset',
f'"{self.namespace}.timeserie"."{tablename}" as tstable'
).where(
......@@ -326,16 +324,14 @@ class timeseries(SeriesServices):
def insertion_dates(self, cn, seriename,
fromdate=None, todate=None):
tablename = self._serie_to_tablename(cn, seriename)
q = sqlq(
q = select(
'insertion_date'
).relation(
).table(
f'"{self.namespace}".changeset as cset',
f'"{self.namespace}.timeserie"."{tablename}" as tstable'
).where(
'cset.id = tstable.cset'
).option(
'order by cset.id'
)
).order('cset.id')
if fromdate:
q.where(
......@@ -365,9 +361,9 @@ class timeseries(SeriesServices):
}
tablename = self._serie_to_tablename(cn, seriename)
assert mode in operators
q = sqlq(
q = select(
'cset'
).relation(
).table(
f'"{self.namespace}.timeserie"."{tablename}" as tstable',
f'"{self.namespace}".changeset as cset '
).where(
......@@ -474,10 +470,10 @@ class timeseries(SeriesServices):
"""
log = []
q = sqlq(
q = select(
'cset.id', 'cset.author', 'cset.insertion_date', 'cset.metadata',
opt='distinct'
).relation(
).table(
f'"{self.namespace}".changeset as cset'
).join(
f'"{self.namespace}".changeset_series as css on css.cset = cset.id',
......@@ -503,9 +499,9 @@ class timeseries(SeriesServices):
if todate:
q.where('cset.insertion_date <= %(todate)s', todate=todate)
q.option('order by cset.id desc')
q.order('cset.id', 'desc')
if limit:
q.option('limit %(limit)s', limit=limit)
q.limit(int(limit))
rset = q.do(cn)
for csetid, author, revdate, meta in rset.fetchall():
......@@ -715,9 +711,9 @@ class timeseries(SeriesServices):
).scalar()
def _changeset_series(self, cn, csid):
q = sqlq(
q = select(
'seriename'
).relation(
).table(
f'"{self.namespace}".registry as reg',
).join(
f'"{self.namespace}".changeset_series as css on css.serie = reg.id'
......
......@@ -16,88 +16,6 @@ from sqlalchemy.engine.base import Engine
from inireader import reader
def sqlfile(path, **kw):
sql = path.read_text()
return sql.format(**kw)
class sqlp:
"""utility to carry an sql fragment plus the potenially needed
parameters
"""
__slots__ = ('sql', 'kw')
def __init__(self, sql, **kw):
self.sql = sql
self.kw = kw
class sqlq:
"""utility to incrementally build an sql query string along with its
parameters
"""
__slots__ = ('head', 'headopt', 'relations', 'joins', 'wheres', 'options', 'kw')
def __init__(self, *head, opt=''):
assert opt in ('', 'distinct')
self.head = list(head)
self.headopt = opt
self.relations = []
self.joins = []
self.wheres = []
self.options = []
self.kw = {}
def select(self, *select):
self.head.append(', '.join(select))
return self
def relation(self, *relations):
self.relations.append(', '.join(relations))
return self
def join(self, *joins, jtype='inner'):
assert jtype in ('inner', 'outer')
for j in joins:
self.joins.append(f'{jtype} join {j}')
return self
def where(self, *wheres, **kw):
self.kw.update(kw)
for where in wheres:
if isinstance(where, sqlp):
self.kw.update(where.kw)
where = where.sql
self.wheres.append(where)
return self
def option(self, option, **kw):
self.options.append(option)
self.kw.update(kw)
return self
@property
def sql(self):
select = f'select {self.headopt} ' + ', '.join(self.head)
froms = 'from ' + ', '.join(self.relations)
join = ' '.join(self.joins)
wheres = 'where ' + ' and '.join(self.wheres) if self.wheres else ''
options = ' '.join(self.options)
sql = [select, froms, join, wheres, options]
return ' '.join(sql)
def __str__(self):
return f'query::[{self.sql}, {self.kw}]'
__repr__ = __str__
def do(self, cn):
return cn.execute(
self.sql,
**self.kw
)
@contextmanager
def tempdir(suffix='', prefix='tmp'):
tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment