db.rs 7.0 KB


  1. use std::io::prelude::*;
  2. use log::*;
  3. use anyhow::{anyhow, Result};
  4. use os_pipe::{pipe, PipeWriter};
  5. use postgres::{TlsMode};
  6. use structopt::StructOpt;
  7. pub use postgres::Connection;
  8. use std::thread;
  9. pub trait ConnectInfo {
  10. fn db_url(&self) -> Result<String>;
  11. }
  12. impl ConnectInfo for String {
  13. fn db_url(&self) -> Result<String> {
  14. Ok(self.clone())
  15. }
  16. }
  17. impl ConnectInfo for Option<String> {
  18. fn db_url(&self) -> Result<String> {
  19. match self {
  20. Some(ref s) => Ok(s.clone()),
  21. None => Err(anyhow!("no URL provided"))
  22. }
  23. }
  24. }
/// Database options
// Command-line options selecting which database (and schema) to use.
// NOTE(review): the `///` comments on this struct and its fields are presumably
// consumed by the `StructOpt` derive as CLI help text, so edit them with care.
#[derive(StructOpt, Debug, Clone)]
pub struct DbOpts {
    /// Database URL to connect to
    // When absent, `DbOpts::url` falls back to the `DB_URL` environment variable.
    #[structopt(long="db-url")]
    db_url: Option<String>,
    /// Database schema
    // When absent, `DbOpts::schema` reports "public".
    #[structopt(long="db-schema")]
    db_schema: Option<String>
}
  35. impl DbOpts {
  36. /// Open the database connection
  37. pub fn open(&self) -> Result<Connection> {
  38. let url = self.url()?;
  39. connect(&url)
  40. }
  41. pub fn url<'a>(&'a self) -> Result<String> {
  42. Ok(match self.db_url {
  43. Some(ref s) => s.clone(),
  44. None => std::env::var("DB_URL")?
  45. })
  46. }
  47. /// Get the DB schema
  48. pub fn schema<'a>(&'a self) -> &'a str {
  49. match self.db_schema {
  50. Some(ref s) => s,
  51. None => "public"
  52. }
  53. }
  54. /// Change the default schema
  55. pub fn default_schema(self, default: &str) -> DbOpts {
  56. DbOpts {
  57. db_url: self.db_url,
  58. db_schema: self.db_schema.or_else(|| Some(default.to_string()))
  59. }
  60. }
  61. }
  62. impl ConnectInfo for DbOpts {
  63. fn db_url(&self) -> Result<String> {
  64. self.url()
  65. }
  66. }
  67. pub fn connect(url: &str) -> Result<Connection> {
  68. Ok(Connection::connect(url, TlsMode::None)?)
  69. }
  70. pub struct CopyRequest {
  71. db_url: String,
  72. schema: Option<String>,
  73. table: String,
  74. columns: Option<Vec<String>>,
  75. format: Option<String>,
  76. truncate: bool,
  77. name: String
  78. }
  79. impl CopyRequest {
  80. pub fn new<C: ConnectInfo>(db: &C, table: &str) -> Result<CopyRequest> {
  81. Ok(CopyRequest {
  82. db_url: db.db_url()?,
  83. schema: None,
  84. table: table.to_string(),
  85. columns: None,
  86. format: None,
  87. truncate: false,
  88. name: "copy".to_string()
  89. })
  90. }
  91. pub fn with_schema(self, schema: &str) -> CopyRequest {
  92. CopyRequest {
  93. schema: Some(schema.to_string()),
  94. ..self
  95. }
  96. }
  97. pub fn with_columns(self, columns: &[&str]) -> CopyRequest {
  98. let mut cvec = Vec::with_capacity(columns.len());
  99. for c in columns {
  100. cvec.push(c.to_string());
  101. }
  102. CopyRequest {
  103. columns: Some(cvec),
  104. ..self
  105. }
  106. }
  107. pub fn with_format(self, format: &str) -> CopyRequest {
  108. CopyRequest {
  109. format: Some(format.to_string()),
  110. ..self
  111. }
  112. }
  113. pub fn with_name(self, name: &str) -> CopyRequest {
  114. CopyRequest {
  115. name: name.to_string(),
  116. ..self
  117. }
  118. }
  119. pub fn truncate(self, trunc: bool) -> CopyRequest {
  120. CopyRequest {
  121. truncate: trunc,
  122. ..self
  123. }
  124. }
  125. pub fn table(&self) -> String {
  126. match self.schema {
  127. Some(ref s) => format!("{}.{}", s, self.table),
  128. None => self.table.clone()
  129. }
  130. }
  131. fn query(&self) -> String {
  132. let mut query = format!("COPY {}", self.table());
  133. if let Some(ref cs) = self.columns {
  134. let s = format!(" ({})", cs.join(", "));
  135. query.push_str(&s);
  136. }
  137. query.push_str(" FROM STDIN");
  138. if let Some(ref fmt) = self.format {
  139. query.push_str(&format!(" (FORMAT {})", fmt));
  140. }
  141. query
  142. }
  143. /// Open a writer for a copy request
  144. pub fn open(self) -> Result<CopyTarget> {
  145. let query = self.query();
  146. let (mut reader, writer) = pipe()?;
  147. let name = self.name.clone();
  148. let tb = thread::Builder::new().name(name.clone());
  149. let jh = tb.spawn(move || {
  150. let query = query;
  151. let db = connect(&self.db_url).unwrap();
  152. let mut cfg = postgres::transaction::Config::new();
  153. cfg.isolation_level(postgres::transaction::IsolationLevel::ReadUncommitted);
  154. let tx = db.transaction_with(&cfg).unwrap();
  155. if self.truncate {
  156. let tq = format!("TRUNCATE {}", self.table());
  157. info!("running {}", tq);
  158. tx.execute(&tq, &[]).unwrap();
  159. }
  160. info!("preparing {}", query);
  161. let stmt = tx.prepare(&query).unwrap();
  162. let n = stmt.copy_in(&[], &mut reader).unwrap();
  163. info!("committing copy");
  164. tx.commit().unwrap();
  165. n
  166. })?;
  167. Ok(CopyTarget {
  168. writer: Some(writer),
  169. name: name,
  170. thread: Some(jh)
  171. })
  172. }
  173. }
  174. /// Writer for copy-in operations
  175. ///
  176. /// This writer writes to the copy-in for PostgreSQL. It is unbuffered; you usually
  177. /// want to wrap it in a `BufWriter`.
  178. pub struct CopyTarget {
  179. writer: Option<PipeWriter>,
  180. name: String,
  181. thread: Option<thread::JoinHandle<u64>>
  182. }
  183. impl CopyTarget {
  184. fn do_close(&mut self, warn: bool) -> Result<u64> {
  185. if let Some(w) = self.writer.take() {
  186. std::mem::drop(w);
  187. }
  188. if let Some(thread) = self.thread.take() {
  189. match thread.join() {
  190. Ok(n) => {
  191. info!("{}: wrote {} lines", self.name, n);
  192. Ok(n)
  193. }
  194. Err(e) => {
  195. error!("{}: error: {:?}", self.name, e);
  196. Err(anyhow!("worker thread failed"))
  197. }
  198. }
  199. } else {
  200. if warn {
  201. error!("{} already shut down", self.name);
  202. } else {
  203. debug!("{} already shut down", self.name);
  204. }
  205. Ok(0)
  206. }
  207. }
  208. }
  209. impl Write for CopyTarget {
  210. fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
  211. self.writer.as_ref().expect("writer missing").write(buf)
  212. }
  213. fn flush(&mut self) -> std::io::Result<()> {
  214. self.writer.as_ref().expect("writer missing").flush()
  215. }
  216. }
  217. impl Drop for CopyTarget {
  218. fn drop(&mut self) {
  219. self.do_close(false).unwrap();
  220. }
  221. }
  222. #[test]
  223. fn cr_initial_correct() {
  224. let cr = CopyRequest::new(&("foo".to_string()), "wombat").unwrap();
  225. assert_eq!(cr.name, "copy");
  226. assert_eq!(cr.db_url, "foo");
  227. assert_eq!(cr.table, "wombat");
  228. assert!(cr.columns.is_none());
  229. assert!(cr.schema.is_none());
  230. assert!(!cr.truncate);
  231. assert_eq!(cr.query(), "COPY wombat FROM STDIN");
  232. }
  233. #[test]
  234. fn cr_set_name() {
  235. let cr = CopyRequest::new(&("foo".to_string()), "wombat").unwrap();
  236. let cr = cr.with_name("bob");
  237. assert_eq!(cr.name, "bob");
  238. assert_eq!(cr.db_url, "foo");
  239. assert_eq!(cr.table, "wombat");
  240. assert!(cr.columns.is_none());
  241. assert!(cr.schema.is_none());
  242. assert!(!cr.truncate);
  243. }
  244. #[test]
  245. fn cr_set_format() {
  246. let cr = CopyRequest::new(&("foo".to_string()), "wombat").unwrap();
  247. let cr = cr.with_format("CSV");
  248. assert_eq!(cr.format, Some("CSV".to_string()));
  249. assert_eq!(cr.db_url, "foo");
  250. assert_eq!(cr.table, "wombat");
  251. assert!(cr.columns.is_none());
  252. assert!(cr.schema.is_none());
  253. assert!(!cr.truncate);
  254. }
  255. #[test]
  256. fn cr_schema_propagated() {
  257. let cr = CopyRequest::new(&("foo".to_string()), "wombat").unwrap();
  258. let cr = cr.with_schema("pizza");
  259. assert_eq!(cr.name, "copy");
  260. assert_eq!(cr.db_url, "foo");
  261. assert_eq!(cr.table, "wombat");
  262. assert!(cr.columns.is_none());
  263. assert_eq!(cr.schema.as_ref().expect("no schema"), "pizza");
  264. assert!(!cr.truncate);
  265. assert_eq!(cr.query(), "COPY pizza.wombat FROM STDIN");
  266. }