import os
import sys
import re
import time
import logging
import hashlib
import threading
from configparser import ConfigParser
from pathlib import Path
from contextlib import contextmanager
from datetime import timedelta
from typing import NamedTuple, List

from docopt import docopt
from natural.date import compress as compress_date
import pandas as pd
from more_itertools import peekable
import psycopg2, psycopg2.errorcodes
from psycopg2 import sql
from psycopg2.pool import ThreadedConnectionPool
from sqlalchemy import create_engine
import sqlparse
import git

_log = logging.getLogger(__name__)

# Meta-schema for storing stage and file status in the database
_ms_path = Path(__file__).parent.parent / 'schemas' / 'meta-schema.sql'
meta_schema = _ms_path.read_text()

_pool = None
_engine = None


# DB configuration info
class DBConfig:
    host: str
    port: str
    database: str
    user: str
    password: str

    @classmethod
    def load(cls):
        repo = git.Repo(search_parent_directories=True)
        cfg = ConfigParser()
        _log.debug('reading config from db.cfg')
        cfg.read([repo.working_tree_dir + '/db.cfg'])
        branch = repo.head.reference.name
        _log.info('reading database config for branch %s', branch)
        if branch in cfg:
            section = cfg[branch]
        else:
            _log.debug('no configuration for branch %s, using default', branch)
            section = cfg['DEFAULT']

        dbc = cls()
        dbc.host = section.get('host', 'localhost')
        dbc.port = section.get('port', None)
        dbc.database = section.get('database', None)
        dbc.user = section.get('user', None)
        dbc.password = section.get('password', None)
        if dbc.database is None:
            _log.error('no database specified for branch %s', branch)
            raise RuntimeError('no database specified')
        return dbc

    def url(self) -> str:
        url = 'postgresql://'
        if self.user:
            url += self.user
            if self.password:
                url += ':' + self.password
            url += '@'
        url += self.host
        if self.port:
            url += ':' + self.port
        url += '/' + self.database
        return url
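

# DBConfig.load reads db.cfg at the repository root, using the section named
# after the current Git branch and falling back to DEFAULT.  An illustrative
# configuration (all branch, database, and account names here are hypothetical):
#
#   [DEFAULT]
#   host = localhost
#   database = bookdata
#
#   [my-feature-branch]
#   database = bookdata_test
#   user = bookdata
#   password = secret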


def db_url():
    "Get the URL to connect to the database."
    if 'DB_URL' in os.environ:
        _log.info('using env var DB_URL')
        return os.environ['DB_URL']

    config = DBConfig.load()
    _log.info('using database %s', config.database)
    return config.url()


@contextmanager
def connect():
    """
    Connect to a database.  This context manager yields the connection, and returns
    it to the connection pool when exited.
    """
    global _pool
    if _pool is None:
        _log.info('connecting to %s', db_url())
        _pool = ThreadedConnectionPool(1, 5, db_url())

    conn = _pool.getconn()
    try:
        yield conn
    finally:
        _pool.putconn(conn)
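

# Minimal usage sketch for connect() (illustrative, not from the source):
#
#   with connect() as dbc:
#       with dbc, dbc.cursor() as cur:   # transaction + cursor
#           cur.execute('SELECT 1')
#           print(cur.fetchone())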


def engine():
    "Get an SQLAlchemy engine."
    global _engine
    if _engine is None:
        _log.info('connecting to %s', db_url())
        _engine = create_engine(db_url())
    return _engine


def _tokens(s, start=-1, skip_ws=True, skip_cm=True):
    "Iterate over the tokens of a parsed SQL statement."
    i, t = s.token_next(start, skip_ws=skip_ws, skip_cm=skip_cm)
    while t is not None:
        yield t
        i, t = s.token_next(i, skip_ws=skip_ws, skip_cm=skip_cm)


def describe_statement(s):
    "Describe an SQL statement.  This utility function is used to summarize statements."
    label = s.get_type()
    li, lt = s.token_next(-1, skip_cm=True)
    if lt is None:
        return None
    if lt and lt.ttype == sqlparse.tokens.DDL:
        # DDL statement: build the label from the leading keywords and the
        # object name, skipping IF [NOT] EXISTS clauses
        parts = []
        first = True
        skipping = False
        for t in _tokens(s, li):
            if not first:
                if isinstance(t, (sqlparse.sql.Identifier, sqlparse.sql.Function)):
                    parts.append(t.normalized)
                    break
                elif t.ttype != sqlparse.tokens.Keyword:
                    break

            first = False

            if t.normalized == 'IF':
                skipping = True

            if not skipping:
                parts.append(t.normalized)

        label = label + ' ' + ' '.join(parts)
    elif label == 'UNKNOWN':
        ls = []
        for t in _tokens(s):
            if t.ttype == sqlparse.tokens.Keyword:
                ls.append(t.normalized)
            else:
                break
        if ls:
            label = ' '.join(ls)

        name = s.get_real_name()
        if name:
            label += f' {name}'
    return label
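

# Behavior sketch for describe_statement (illustrative; exact output depends
# on sqlparse tokenization):
#
#   stmt = sqlparse.parse('CREATE INDEX books_idx ON books (id)')[0]
#   describe_statement(stmt)   # -> something like 'CREATE INDEX books_idx'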


def is_empty(s):
    "Check whether an SQL statement is empty."
    lt = s.token_first(skip_cm=True, skip_ws=True)
    return lt is None


class ScriptChunk(NamedTuple):
    "A single chunk of an SQL script."
    label: str
    allowed_errors: List[str]
    src: str
    use_transaction: bool = True

    @property
    def statements(self):
        return [s for s in sqlparse.parse(self.src) if not is_empty(s)]


class SqlScript:
    """
    Class for processing & executing SQL scripts with the following features that
    ``psql`` does not have:

    * Splitting the script into (named) steps, to commit chunks in transactions
    * Recording metadata (currently just dependencies) for the script
    * Allowing chunks to fail with specific errors

    The last feature helps with writing *idempotent* scripts: by allowing a chunk
    to fail with a known error (e.g. creating a constraint that already exists), you
    can write a script that runs cleanly even if it has already been run.

    Args:
        file: the path to the SQL script to read.
    """

    _sep_re = re.compile(r'^---\s*(?P<inst>.*)')
    _icode_re = re.compile(r'#(?P<code>\w+)\s*(?P<args>.*\S)?\s*$')
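
    # Scripts mark steps and metadata with '---' directive lines matched by the
    # regexes above.  An illustrative fragment (schema and table names here are
    # hypothetical; see _parse_script_header and _read_header for the codes):
    #
    #   --- #dep gutenberg-schema
    #   --- #table gutenberg.raw_marc
    #   --- #step Create raw MARC table
    #   --- #allow duplicate_table
    #   CREATE TABLE gutenberg.raw_marc (marc_id SERIAL, marc_xml XML);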

    chunks: List[ScriptChunk]

    def __init__(self, file):
        if hasattr(file, 'read'):
            self._parse(peekable(file))
        else:
            with open(file, 'r', encoding='utf8') as f:
                self._parse(peekable(f))

    def _parse(self, lines):
        self.chunks = []
        self.deps, self.tables = self._parse_script_header(lines)
        next_chunk = self._parse_chunk(lines, len(self.chunks) + 1)
        while next_chunk is not None:
            if next_chunk:
                self.chunks.append(next_chunk)
            next_chunk = self._parse_chunk(lines, len(self.chunks) + 1)

    @classmethod
    def _parse_script_header(cls, lines):
        deps = []
        tables = []

        line = lines.peek(None)
        while line is not None:
            hm = cls._sep_re.match(line)
            if hm is None:
                break

            inst = hm.group('inst')
            cm = cls._icode_re.match(inst)
            if cm is None:
                next(lines)  # eat line
                line = lines.peek(None)
                continue

            code = cm.group('code')
            args = cm.group('args')
            if code == 'dep':
                deps.append(args)
                next(lines)  # eat line
            elif code == 'table':
                parts = args.split('.', 1)
                if len(parts) > 1:
                    ns, tbl = parts
                    tables.append((ns, tbl))
                else:
                    tables.append(('public', args))
                next(lines)  # eat line
            else:  # any other code means we're out of the header
                break

            line = lines.peek(None)

        return deps, tables

    @classmethod
    def _parse_chunk(cls, lines: peekable, n: int):
        chunk = cls._read_header(lines)
        qlines = cls._read_query(lines)

        # did we get a chunk, or did we hit the end of the file?
        if qlines:
            if chunk.label is None:
                chunk = chunk._replace(label=f'Step {n}')
            return chunk._replace(src='\n'.join(qlines))
        elif qlines is not None:
            return False  # empty chunk

    @classmethod
    def _read_header(cls, lines: peekable):
        label = None
        errs = []
        tx = True

        line = lines.peek(None)
        while line is not None:
            hm = cls._sep_re.match(line)
            if hm is None:
                break

            next(lines)  # eat line
            line = lines.peek(None)

            inst = hm.group('inst')
            cm = cls._icode_re.match(inst)
            if cm is None:
                continue

            code = cm.group('code')
            args = cm.group('args')
            if code == 'step':
                label = args
            elif code == 'allow':
                err = getattr(psycopg2.errorcodes, args.upper())
                _log.debug('step allows error %s (%s)', args, err)
                errs.append(err)
            elif code == 'notx':
                _log.debug('chunk will run outside a transaction')
                tx = False
            else:
                _log.error('unrecognized query instruction %s', code)
                raise ValueError(f'invalid query instruction {code}')

        return ScriptChunk(label=label, allowed_errors=errs, src=None,
                           use_transaction=tx)

    @classmethod
    def _read_query(cls, lines: peekable):
        qls = []

        line = lines.peek(None)
        while line is not None and not cls._sep_re.match(line):
            qls.append(next(lines))
            line = lines.peek(None)

        # trim leading and trailing blank lines
        while qls and not qls[0].strip():
            qls.pop(0)
        while qls and not qls[-1].strip():
            qls.pop(-1)

        if qls or line is not None:
            return qls
        else:
            return None  # end of file

    def execute(self, dbc, transcript=None):
        """
        Execute the SQL script.

        Args:
            dbc: the database connection.
            transcript: a file to receive the run transcript.
        """
        all_st = time.perf_counter()
        for step in self.chunks:
            start = time.perf_counter()
            _log.info('Running ‘%s’', step.label)
            if transcript is not None:
                print('CHUNK', step.label, file=transcript)
            if step.use_transaction:
                with dbc, dbc.cursor() as cur:
                    self._run_step(step, dbc, cur, True, transcript)
            else:
                ac = dbc.autocommit
                try:
                    dbc.autocommit = True
                    with dbc.cursor() as cur:
                        self._run_step(step, dbc, cur, False, transcript)
                finally:
                    dbc.autocommit = ac

            elapsed = time.perf_counter() - start
            elapsed = timedelta(seconds=elapsed)
            if transcript is not None:
                print('CHUNK ELAPSED', elapsed, file=transcript)
            _log.info('Finished ‘%s’ in %s', step.label, compress_date(elapsed))

        elapsed = time.perf_counter() - all_st
        elapsed = timedelta(seconds=elapsed)
        _log.info('Script completed in %s', compress_date(elapsed))

    def describe(self):
        for dep in self.deps:
            _log.info('Dependency ‘%s’', dep)
        for step in self.chunks:
            _log.info('Chunk ‘%s’', step.label)
            for s in step.statements:
                _log.info('Statement %s', describe_statement(s))

    def _run_step(self, step, dbc, cur, commit, transcript):
        try:
            for sql in step.statements:
                start = time.perf_counter()
                _log.debug('Executing %s', describe_statement(sql))
                _log.debug('Query: %s', sql)
                if transcript is not None:
                    print('STMT', describe_statement(sql), file=transcript)
                cur.execute(str(sql))
                elapsed = time.perf_counter() - start
                elapsed = timedelta(seconds=elapsed)
                rows = cur.rowcount
                if transcript is not None:
                    print('ELAPSED', elapsed, file=transcript)
                if rows is not None and rows >= 0:
                    if transcript is not None:
                        print('ROWS', rows, file=transcript)
                    _log.info('finished %s in %s (%d rows)', describe_statement(sql),
                              compress_date(elapsed), rows)
                else:
                    _log.info('finished %s in %s', describe_statement(sql),
                              compress_date(elapsed))
            if commit:
                dbc.commit()
        except psycopg2.Error as e:
            if e.pgcode in step.allowed_errors:
                _log.info('Failed with acceptable error %s (%s)',
                          e.pgcode, psycopg2.errorcodes.lookup(e.pgcode))
                if transcript is not None:
                    print('ERROR', e.pgcode, psycopg2.errorcodes.lookup(e.pgcode),
                          file=transcript)
            else:
                _log.error('Error in "%s" %s: %s: %s',
                           step.label, describe_statement(sql),
                           psycopg2.errorcodes.lookup(e.pgcode), e)
                if e.pgerror:
                    _log.info('Query diagnostics:\n%s', e.pgerror)
                raise e
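

# Putting SqlScript together (illustrative sketch; 'gutenberg.sql' is a
# hypothetical script path):
#
#   script = SqlScript('gutenberg.sql')
#   script.describe()               # log dependencies, chunks, and statements
#   with connect() as dbc:
#       script.execute(dbc)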


class _LoadThread(threading.Thread):
    """
    Thread worker for copying database results to a stream we can read.
    """

    def __init__(self, dbc, query, dir='out'):
        super().__init__()
        self.database = dbc
        self.query = query
        # the COPY runs on one end of an OS pipe while the caller reads or
        # writes the other
        rfd, wfd = os.pipe()
        self.reader = os.fdopen(rfd)
        self.writer = os.fdopen(wfd, 'w')
        # for COPY TO (dir='out') the thread writes; for COPY FROM it reads
        self.chan = self.writer if dir == 'out' else self.reader

    def run(self):
        with self.chan, self.database.cursor() as cur:
            cur.copy_expert(self.query, self.chan)


def load_table(dbc, query):
    """
    Load a query into a Pandas data frame.

    This is substantially more efficient than Pandas ``read_sql``, because it directly
    streams CSV data from the database instead of going through SQLAlchemy.
    """
    cq = sql.SQL('COPY ({}) TO STDOUT WITH CSV HEADER')
    q = sql.SQL(query)
    thread = _LoadThread(dbc, cq.format(q))
    thread.start()
    data = pd.read_csv(thread.reader)
    thread.join()
    return data
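

# Illustrative use of load_table (table and column names are hypothetical):
#
#   with connect() as dbc:
#       books = load_table(dbc, 'SELECT book_id, title FROM books')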


def save_table(dbc, table, data: pd.DataFrame):
    """
    Save a Pandas data frame to a database table.

    This is substantially more efficient than Pandas ``to_sql``, because it directly
    streams CSV data to the database instead of going through SQLAlchemy.
    """
    cq = sql.SQL('COPY {} FROM STDIN WITH CSV')
    thread = _LoadThread(dbc, cq.format(table), 'in')
    thread.start()
    data.to_csv(thread.writer, header=False, index=False)
    thread.writer.close()
    thread.join()
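

# Illustrative use of save_table; since save_table formats the target with
# psycopg2.sql, the table is passed as a composable (names hypothetical):
#
#   with connect() as dbc:
#       save_table(dbc, sql.Identifier('books'), books)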