Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 06bf559

Browse files
committed
Refactor connect() into Connect class
1 parent 66db85c commit 06bf559

File tree

10 files changed

+191
-147
lines changed

10 files changed

+191
-147
lines changed

data_diff/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Sequence, Tuple, Iterator, Optional, Union
22

33
from .tracking import disable_tracking
4-
from .sqeleton.databases.connect import connect
4+
from .databases import connect
55
from .sqeleton.databases.database_types import DbKey, DbTime, DbPath
66
from .diff_tables import Algorithm
77
from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR

data_diff/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from .hashdiff_tables import HashDiffer, DEFAULT_BISECTION_THRESHOLD, DEFAULT_BISECTION_FACTOR
1515
from .joindiff_tables import TABLE_WRITE_LIMIT, JoinDiffer
1616
from .table_segment import TableSegment
17-
from .databases.database_types import create_schema
18-
from .databases.connect import connect
17+
from .sqeleton.databases.database_types import create_schema
18+
from .databases import connect
1919
from .parse_time import parse_time_before_now, UNITS_STR, ParseError
2020
from .config import apply_config_from_file
2121
from .tracking import disable_tracking

data_diff/databases/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# from data_diff.sqeleton.databases.base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, QueryError, ConnectError
1+
from data_diff.sqeleton.databases.base import MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, QueryError, ConnectError
22

33
from .postgresql import PostgreSQL
44
from .mysql import MySQL
@@ -13,4 +13,5 @@
1313
from .vertica import Vertica
1414
from .duckdb import DuckDB
1515

16-
from .connect import connect_to_uri
16+
from ._connect import connect
17+

data_diff/databases/_connect.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from data_diff.sqeleton.databases.connect import MatchUriPath, Connect
2+
3+
from .postgresql import PostgreSQL
4+
from .mysql import MySQL
5+
from .oracle import Oracle
6+
from .snowflake import Snowflake
7+
from .bigquery import BigQuery
8+
from .redshift import Redshift
9+
from .presto import Presto
10+
from .databricks import Databricks
11+
from .trino import Trino
12+
from .clickhouse import Clickhouse
13+
from .vertica import Vertica
14+
15+
16+
17+
MATCH_URI_PATH = {
18+
"postgresql": MatchUriPath(PostgreSQL, ["database?"], help_str="postgresql://<user>:<pass>@<host>/<database>"),
19+
"mysql": MatchUriPath(MySQL, ["database?"], help_str="mysql://<user>:<pass>@<host>/<database>"),
20+
"oracle": MatchUriPath(Oracle, ["database?"], help_str="oracle://<user>:<pass>@<host>/<database>"),
21+
# "mssql": MatchUriPath(MsSQL, ["database?"], help_str="mssql://<user>:<pass>@<host>/<database>"),
22+
"redshift": MatchUriPath(Redshift, ["database?"], help_str="redshift://<user>:<pass>@<host>/<database>"),
23+
"snowflake": MatchUriPath(
24+
Snowflake,
25+
["database", "schema"],
26+
["warehouse"],
27+
help_str="snowflake://<user>:<pass>@<account>/<database>/<SCHEMA>?warehouse=<WAREHOUSE>",
28+
),
29+
"presto": MatchUriPath(Presto, ["catalog", "schema"], help_str="presto://<user>@<host>/<catalog>/<schema>"),
30+
"bigquery": MatchUriPath(BigQuery, ["dataset"], help_str="bigquery://<project>/<dataset>"),
31+
"databricks": MatchUriPath(
32+
Databricks,
33+
["catalog", "schema"],
34+
help_str="databricks://:access_token@server_name/http_path",
35+
),
36+
"trino": MatchUriPath(Trino, ["catalog", "schema"], help_str="trino://<user>@<host>/<catalog>/<schema>"),
37+
"clickhouse": MatchUriPath(Clickhouse, ["database?"], help_str="clickhouse://<user>:<pass>@<host>/<database>"),
38+
"vertica": MatchUriPath(Vertica, ["database?"], help_str="vertica://<user>:<pass>@<host>/<database>"),
39+
}
40+
41+
connect = Connect(MATCH_URI_PATH)

data_diff/databases/connect.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

data_diff/databases/postgresql.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from data_diff.sqeleton.databases.postgresql import PostgresqlDialect, PostgreSQL
1+
from data_diff.sqeleton.databases import postgresql
22
from .base import BaseDialect
33

4-
class PostgresqlDialect(BaseDialect, PostgresqlDialect):
4+
class PostgresqlDialect(BaseDialect, postgresql.PostgresqlDialect):
55
pass
66

7-
class PostgreSQL(PostgreSQL):
7+
class PostgreSQL(postgresql.PostgreSQL):
88
dialect = PostgresqlDialect()

data_diff/sqeleton/databases/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,3 @@
1212
from .clickhouse import Clickhouse
1313
from .vertica import Vertica
1414

15-
from .connect import connect_to_uri

data_diff/sqeleton/databases/connect.py

Lines changed: 133 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Type, List, Optional, Union
1+
from typing import Type, List, Optional, Union, Dict
22
from itertools import zip_longest
33
import dsnparse
44

@@ -94,134 +94,138 @@ def match_path(self, dsn):
9494
}
9595

9696

97-
def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database:
98-
"""Connect to the given database uri
99-
100-
thread_count determines the max number of worker threads per database,
101-
if relevant. None means no limit.
102-
103-
Parameters:
104-
db_uri (str): The URI for the database to connect
105-
thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1)
106-
107-
Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck.
108-
109-
Supported schemes:
110-
- postgresql
111-
- mysql
112-
- oracle
113-
- snowflake
114-
- bigquery
115-
- redshift
116-
- presto
117-
- databricks
118-
- trino
119-
- clickhouse
120-
- vertica
121-
"""
122-
123-
dsn = dsnparse.parse(db_uri)
124-
if len(dsn.schemes) > 1:
125-
raise NotImplementedError("No support for multiple schemes")
126-
(scheme,) = dsn.schemes
127-
128-
try:
129-
matcher = MATCH_URI_PATH[scheme]
130-
except KeyError:
131-
raise NotImplementedError(f"Scheme {scheme} currently not supported")
132-
133-
cls = matcher.database_cls
134-
135-
if scheme == "databricks":
136-
assert not dsn.user
137-
kw = {}
138-
kw["access_token"] = dsn.password
139-
kw["http_path"] = dsn.path
140-
kw["server_hostname"] = dsn.host
141-
kw.update(dsn.query)
142-
elif scheme == 'duckdb':
143-
kw = {}
144-
kw['filepath'] = dsn.dbname
145-
kw['dbname'] = dsn.user
146-
else:
147-
kw = matcher.match_path(dsn)
148-
149-
if scheme == "bigquery":
150-
kw["project"] = dsn.host
151-
return cls(**kw)
152-
153-
if scheme == "snowflake":
154-
kw["account"] = dsn.host
155-
assert not dsn.port
156-
kw["user"] = dsn.user
157-
kw["password"] = dsn.password
97+
@dataclass
98+
class Connect:
99+
match_uri_path: Dict[str, MatchUriPath]
100+
101+
def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Database:
102+
"""Connect to the given database uri
103+
104+
thread_count determines the max number of worker threads per database,
105+
if relevant. None means no limit.
106+
107+
Parameters:
108+
db_uri (str): The URI for the database to connect
109+
thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1)
110+
111+
Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck.
112+
113+
Supported schemes:
114+
- postgresql
115+
- mysql
116+
- oracle
117+
- snowflake
118+
- bigquery
119+
- redshift
120+
- presto
121+
- databricks
122+
- trino
123+
- clickhouse
124+
- vertica
125+
"""
126+
127+
dsn = dsnparse.parse(db_uri)
128+
if len(dsn.schemes) > 1:
129+
raise NotImplementedError("No support for multiple schemes")
130+
(scheme,) = dsn.schemes
131+
132+
try:
133+
matcher = self.match_uri_path[scheme]
134+
except KeyError:
135+
raise NotImplementedError(f"Scheme {scheme} currently not supported")
136+
137+
cls = matcher.database_cls
138+
139+
if scheme == "databricks":
140+
assert not dsn.user
141+
kw = {}
142+
kw["access_token"] = dsn.password
143+
kw["http_path"] = dsn.path
144+
kw["server_hostname"] = dsn.host
145+
kw.update(dsn.query)
146+
elif scheme == 'duckdb':
147+
kw = {}
148+
kw['filepath'] = dsn.dbname
149+
kw['dbname'] = dsn.user
158150
else:
159-
kw["host"] = dsn.host
160-
kw["port"] = dsn.port
161-
kw["user"] = dsn.user
162-
if dsn.password:
163-
kw["password"] = dsn.password
164-
165-
kw = {k: v for k, v in kw.items() if v is not None}
166-
167-
if issubclass(cls, ThreadedDatabase):
168-
return cls(thread_count=thread_count, **kw)
169-
170-
return cls(**kw)
171-
172-
173-
def connect_with_dict(d, thread_count):
174-
d = dict(d)
175-
driver = d.pop("driver")
176-
try:
177-
matcher = MATCH_URI_PATH[driver]
178-
except KeyError:
179-
raise NotImplementedError(f"Driver {driver} currently not supported")
180-
181-
cls = matcher.database_cls
182-
if issubclass(cls, ThreadedDatabase):
183-
return cls(thread_count=thread_count, **d)
184-
185-
return cls(**d)
151+
kw = matcher.match_path(dsn)
186152

153+
if scheme == "bigquery":
154+
kw["project"] = dsn.host
155+
return cls(**kw)
187156

188-
def connect(db_conf: Union[str, dict], thread_count: Optional[int] = 1) -> Database:
189-
"""Connect to a database using the given database configuration.
190-
191-
Configuration can be given either as a URI string, or as a dict of {option: value}.
192-
193-
The dictionary configuration uses the same keys as the TOML 'database' definition given with --conf.
194-
195-
thread_count determines the max number of worker threads per database,
196-
if relevant. None means no limit.
197-
198-
Parameters:
199-
db_conf (str | dict): The configuration for the database to connect. URI or dict.
200-
thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1)
201-
202-
Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck.
203-
204-
Supported drivers:
205-
- postgresql
206-
- mysql
207-
- oracle
208-
- snowflake
209-
- bigquery
210-
- redshift
211-
- presto
212-
- databricks
213-
- trino
214-
- clickhouse
215-
- vertica
216-
217-
Example:
218-
>>> connect("mysql://localhost/db")
219-
<data_diff.databases.mysql.MySQL object at 0x0000025DB45F4190>
220-
>>> connect({"driver": "mysql", "host": "localhost", "database": "db"})
221-
<data_diff.databases.mysql.MySQL object at 0x0000025DB3F94820>
222-
"""
223-
if isinstance(db_conf, str):
224-
return connect_to_uri(db_conf, thread_count)
225-
elif isinstance(db_conf, dict):
226-
return connect_with_dict(db_conf, thread_count)
227-
raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.")
157+
if scheme == "snowflake":
158+
kw["account"] = dsn.host
159+
assert not dsn.port
160+
kw["user"] = dsn.user
161+
kw["password"] = dsn.password
162+
else:
163+
kw["host"] = dsn.host
164+
kw["port"] = dsn.port
165+
kw["user"] = dsn.user
166+
if dsn.password:
167+
kw["password"] = dsn.password
168+
169+
kw = {k: v for k, v in kw.items() if v is not None}
170+
171+
if issubclass(cls, ThreadedDatabase):
172+
return cls(thread_count=thread_count, **kw)
173+
174+
return cls(**kw)
175+
176+
177+
def connect_with_dict(self, d, thread_count):
178+
d = dict(d)
179+
driver = d.pop("driver")
180+
try:
181+
matcher = self.match_uri_path[driver]
182+
except KeyError:
183+
raise NotImplementedError(f"Driver {driver} currently not supported")
184+
185+
cls = matcher.database_cls
186+
if issubclass(cls, ThreadedDatabase):
187+
return cls(thread_count=thread_count, **d)
188+
189+
return cls(**d)
190+
191+
192+
def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1) -> Database:
193+
"""Connect to a database using the given database configuration.
194+
195+
Configuration can be given either as a URI string, or as a dict of {option: value}.
196+
197+
The dictionary configuration uses the same keys as the TOML 'database' definition given with --conf.
198+
199+
thread_count determines the max number of worker threads per database,
200+
if relevant. None means no limit.
201+
202+
Parameters:
203+
db_conf (str | dict): The configuration for the database to connect. URI or dict.
204+
thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1)
205+
206+
Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck.
207+
208+
Supported drivers:
209+
- postgresql
210+
- mysql
211+
- oracle
212+
- snowflake
213+
- bigquery
214+
- redshift
215+
- presto
216+
- databricks
217+
- trino
218+
- clickhouse
219+
- vertica
220+
221+
Example:
222+
>>> connect("mysql://localhost/db")
223+
<data_diff.databases.mysql.MySQL object at 0x0000025DB45F4190>
224+
>>> connect({"driver": "mysql", "host": "localhost", "database": "db"})
225+
<data_diff.databases.mysql.MySQL object at 0x0000025DB3F94820>
226+
"""
227+
if isinstance(db_conf, str):
228+
return self.connect_to_uri(db_conf, thread_count)
229+
elif isinstance(db_conf, dict):
230+
return self.connect_with_dict(db_conf, thread_count)
231+
raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.")

tests/test_database.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import unittest
22

33
from .common import str_to_checksum, TEST_MYSQL_CONN_STRING
4-
from data_diff.databases import connect_to_uri
4+
from data_diff.databases import connect
55

66

77
class TestDatabase(unittest.TestCase):
88
def setUp(self):
9-
self.mysql = connect_to_uri(TEST_MYSQL_CONN_STRING)
9+
self.mysql = connect(TEST_MYSQL_CONN_STRING)
1010

1111
def test_connect_to_db(self):
1212
self.assertEqual(1, self.mysql.query("SELECT 1", int))
@@ -21,9 +21,9 @@ def test_md5_as_int(self):
2121

2222
class TestConnect(unittest.TestCase):
2323
def test_bad_uris(self):
24-
self.assertRaises(ValueError, connect_to_uri, "p")
25-
self.assertRaises(ValueError, connect_to_uri, "postgresql:///bla/foo")
26-
self.assertRaises(ValueError, connect_to_uri, "snowflake://user:pass@bya42734/xdiffdev/TEST1")
24+
self.assertRaises(ValueError, connect, "p")
25+
self.assertRaises(ValueError, connect, "postgresql:///bla/foo")
26+
self.assertRaises(ValueError, connect, "snowflake://user:pass@bya42734/xdiffdev/TEST1")
2727
self.assertRaises(
28-
ValueError, connect_to_uri, "snowflake://user:pass@bya42734/xdiffdev/TEST1?warehouse=ha&schema=dup"
28+
ValueError, connect, "snowflake://user:pass@bya42734/xdiffdev/TEST1?warehouse=ha&schema=dup"
2929
)

0 commit comments

Comments
 (0)