Commit fabdf5c8 authored by Aurélien Campéas's avatar Aurélien Campéas
Browse files

eliminate the temptation of sql injection by using a small sql building utility

`sqlp` carries an sql string fragment plus its needed parameters
parent fcda8d21cedc
......@@ -8,7 +8,8 @@ import numpy as np
from tshistory.util import (
SeriesServices,
sqlfile
sqlfile,
sqlp,
)
SCHEMA = Path(__file__).parent / 'snapshot.sql'
......@@ -216,25 +217,27 @@ class Snapshot(SeriesServices):
f' "{self.tsh.namespace}".changeset as cset'
' where cset.id = ts.cset '
]
params = {}
if csetfilter:
sql.append('and ts.cset <= cset.id ')
for filtercb in csetfilter:
sql.append('and ' + filtercb)
sql.append('and ' + filtercb.sql)
params.update(filtercb.kw)
sql.append(f'order by ts.id {order} ')
return sql
return sql, params
def find(self, csetfilter=(),
from_value_date=None, to_value_date=None):
sql = self.cset_heads_query(csetfilter)
sql, params = self.cset_heads_query(csetfilter)
sql.append('limit 1')
sql = ''.join(sql)
try:
csid, cid = self.cn.execute(sql).fetchone()
csid, cid = self.cn.execute(sql, **params).fetchone()
except TypeError:
# this happens *only* because of the from/to restriction
return None, None
......@@ -262,14 +265,18 @@ class Snapshot(SeriesServices):
csets = [rev for rev, _ in revs if rev is not None]
# csid -> heads
sql = self.cset_heads_query((f'cset.id >= {min(csets)}',
f'cset.id <= {max(csets)}'),
order='asc')
sql, params = self.cset_heads_query(
(
sqlp('cset.id >= %(mincset)s', mincset=min(csets)),
sqlp('cset.id <= %(maxcset)s', maxcset=max(csets))
),
order='asc'
)
sql = ''.join(sql)
cset_snap_map = {
row.cset: row.snapshot
for row in self.cn.execute(sql).fetchall()
for row in self.cn.execute(sql, **params).fetchall()
}
rawchunks = self.allchunks(
sorted(cset_snap_map.values()),
......
......@@ -16,6 +16,7 @@ from tshistory.util import (
SeriesServices,
start_end,
sqlfile,
sqlp,
tx,
tzaware_serie
)
......@@ -114,7 +115,9 @@ class timeseries(SeriesServices):
csetfilter = []
if revision_date:
csetfilter.append(
f'cset.insertion_date <= \'{revision_date.isoformat()}\''
sqlp(
f'cset.insertion_date <= %(idate)s', idate=revision_date
)
)
snap = Snapshot(cn, self, seriename)
_, current = snap.find(csetfilter=csetfilter,
......@@ -231,10 +234,14 @@ class timeseries(SeriesServices):
to_date = idate + deltaafter
series.append((
idate,
snapshot.find(csetfilter=[f'cset.id = {csid}'],
from_value_date=from_date,
to_value_date=to_date)[1]
))
snapshot.find(
csetfilter=[
sqlp('cset.id = %(csid)s', csid=csid)
],
from_value_date=from_date,
to_value_date=to_date)[1]
)
)
else:
series = snapshot.findall(revs,
from_value_date,
......
......@@ -21,6 +21,14 @@ def sqlfile(path, **kw):
return sql.format(**kw)
class sqlp:
__slots__ = ('sql', 'kw')
def __init__(self, sql, **kw):
self.sql = f'{sql.strip()} '
self.kw = kw
@contextmanager
def tempdir(suffix='', prefix='tmp'):
tmp = tempfile.mkdtemp(suffix=suffix, prefix=prefix)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment