|
1 | | -from typing import Type, List, Optional, Union |
| 1 | +from typing import Type, List, Optional, Union, Dict |
2 | 2 | from itertools import zip_longest |
3 | 3 | import dsnparse |
4 | 4 |
|
@@ -94,134 +94,138 @@ def match_path(self, dsn): |
94 | 94 | } |
95 | 95 |
|
96 | 96 |
|
97 | | -def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database: |
98 | | - """Connect to the given database uri |
99 | | -
|
100 | | - thread_count determines the max number of worker threads per database, |
101 | | - if relevant. None means no limit. |
102 | | -
|
103 | | - Parameters: |
104 | | - db_uri (str): The URI for the database to connect |
105 | | - thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) |
106 | | -
|
107 | | - Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. |
108 | | -
|
109 | | - Supported schemes: |
110 | | - - postgresql |
111 | | - - mysql |
112 | | - - oracle |
113 | | - - snowflake |
114 | | - - bigquery |
115 | | - - redshift |
116 | | - - presto |
117 | | - - databricks |
118 | | - - trino |
119 | | - - clickhouse |
120 | | - - vertica |
121 | | - """ |
122 | | - |
123 | | - dsn = dsnparse.parse(db_uri) |
124 | | - if len(dsn.schemes) > 1: |
125 | | - raise NotImplementedError("No support for multiple schemes") |
126 | | - (scheme,) = dsn.schemes |
127 | | - |
128 | | - try: |
129 | | - matcher = MATCH_URI_PATH[scheme] |
130 | | - except KeyError: |
131 | | - raise NotImplementedError(f"Scheme {scheme} currently not supported") |
132 | | - |
133 | | - cls = matcher.database_cls |
134 | | - |
135 | | - if scheme == "databricks": |
136 | | - assert not dsn.user |
137 | | - kw = {} |
138 | | - kw["access_token"] = dsn.password |
139 | | - kw["http_path"] = dsn.path |
140 | | - kw["server_hostname"] = dsn.host |
141 | | - kw.update(dsn.query) |
142 | | - elif scheme == 'duckdb': |
143 | | - kw = {} |
144 | | - kw['filepath'] = dsn.dbname |
145 | | - kw['dbname'] = dsn.user |
146 | | - else: |
147 | | - kw = matcher.match_path(dsn) |
148 | | - |
149 | | - if scheme == "bigquery": |
150 | | - kw["project"] = dsn.host |
151 | | - return cls(**kw) |
152 | | - |
153 | | - if scheme == "snowflake": |
154 | | - kw["account"] = dsn.host |
155 | | - assert not dsn.port |
156 | | - kw["user"] = dsn.user |
157 | | - kw["password"] = dsn.password |
| 97 | +@dataclass |
| 98 | +class Connect: |
| 99 | + match_uri_path: Dict[str, MatchUriPath] |
| 100 | + |
| 101 | + def connect_to_uri(self, db_uri: str, thread_count: Optional[int] = 1) -> Database: |
| 102 | + """Connect to the given database uri |
| 103 | +
|
| 104 | + thread_count determines the max number of worker threads per database, |
| 105 | + if relevant. None means no limit. |
| 106 | +
|
| 107 | + Parameters: |
| 108 | + db_uri (str): The URI for the database to connect |
| 109 | + thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) |
| 110 | +
|
| 111 | + Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. |
| 112 | +
|
| 113 | + Supported schemes: |
| 114 | + - postgresql |
| 115 | + - mysql |
| 116 | + - oracle |
| 117 | + - snowflake |
| 118 | + - bigquery |
| 119 | + - redshift |
| 120 | + - presto |
| 121 | + - databricks |
| 122 | + - trino |
| 123 | + - clickhouse |
| 124 | + - vertica |
| 125 | + """ |
| 126 | + |
| 127 | + dsn = dsnparse.parse(db_uri) |
| 128 | + if len(dsn.schemes) > 1: |
| 129 | + raise NotImplementedError("No support for multiple schemes") |
| 130 | + (scheme,) = dsn.schemes |
| 131 | + |
| 132 | + try: |
| 133 | + matcher = self.match_uri_path[scheme] |
| 134 | + except KeyError: |
| 135 | + raise NotImplementedError(f"Scheme {scheme} currently not supported") |
| 136 | + |
| 137 | + cls = matcher.database_cls |
| 138 | + |
| 139 | + if scheme == "databricks": |
| 140 | + assert not dsn.user |
| 141 | + kw = {} |
| 142 | + kw["access_token"] = dsn.password |
| 143 | + kw["http_path"] = dsn.path |
| 144 | + kw["server_hostname"] = dsn.host |
| 145 | + kw.update(dsn.query) |
| 146 | + elif scheme == 'duckdb': |
| 147 | + kw = {} |
| 148 | + kw['filepath'] = dsn.dbname |
| 149 | + kw['dbname'] = dsn.user |
158 | 150 | else: |
159 | | - kw["host"] = dsn.host |
160 | | - kw["port"] = dsn.port |
161 | | - kw["user"] = dsn.user |
162 | | - if dsn.password: |
163 | | - kw["password"] = dsn.password |
164 | | - |
165 | | - kw = {k: v for k, v in kw.items() if v is not None} |
166 | | - |
167 | | - if issubclass(cls, ThreadedDatabase): |
168 | | - return cls(thread_count=thread_count, **kw) |
169 | | - |
170 | | - return cls(**kw) |
171 | | - |
172 | | - |
173 | | -def connect_with_dict(d, thread_count): |
174 | | - d = dict(d) |
175 | | - driver = d.pop("driver") |
176 | | - try: |
177 | | - matcher = MATCH_URI_PATH[driver] |
178 | | - except KeyError: |
179 | | - raise NotImplementedError(f"Driver {driver} currently not supported") |
180 | | - |
181 | | - cls = matcher.database_cls |
182 | | - if issubclass(cls, ThreadedDatabase): |
183 | | - return cls(thread_count=thread_count, **d) |
184 | | - |
185 | | - return cls(**d) |
| 151 | + kw = matcher.match_path(dsn) |
186 | 152 |
|
| 153 | + if scheme == "bigquery": |
| 154 | + kw["project"] = dsn.host |
| 155 | + return cls(**kw) |
187 | 156 |
|
188 | | -def connect(db_conf: Union[str, dict], thread_count: Optional[int] = 1) -> Database: |
189 | | - """Connect to a database using the given database configuration. |
190 | | -
|
191 | | - Configuration can be given either as a URI string, or as a dict of {option: value}. |
192 | | -
|
193 | | - The dictionary configuration uses the same keys as the TOML 'database' definition given with --conf. |
194 | | -
|
195 | | - thread_count determines the max number of worker threads per database, |
196 | | - if relevant. None means no limit. |
197 | | -
|
198 | | - Parameters: |
199 | | - db_conf (str | dict): The configuration for the database to connect. URI or dict. |
200 | | - thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) |
201 | | -
|
202 | | - Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. |
203 | | -
|
204 | | - Supported drivers: |
205 | | - - postgresql |
206 | | - - mysql |
207 | | - - oracle |
208 | | - - snowflake |
209 | | - - bigquery |
210 | | - - redshift |
211 | | - - presto |
212 | | - - databricks |
213 | | - - trino |
214 | | - - clickhouse |
215 | | - - vertica |
216 | | -
|
217 | | - Example: |
218 | | - >>> connect("mysql://localhost/db") |
219 | | - <data_diff.databases.mysql.MySQL object at 0x0000025DB45F4190> |
220 | | - >>> connect({"driver": "mysql", "host": "localhost", "database": "db"}) |
221 | | - <data_diff.databases.mysql.MySQL object at 0x0000025DB3F94820> |
222 | | - """ |
223 | | - if isinstance(db_conf, str): |
224 | | - return connect_to_uri(db_conf, thread_count) |
225 | | - elif isinstance(db_conf, dict): |
226 | | - return connect_with_dict(db_conf, thread_count) |
227 | | - raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.") |
| 157 | + if scheme == "snowflake": |
| 158 | + kw["account"] = dsn.host |
| 159 | + assert not dsn.port |
| 160 | + kw["user"] = dsn.user |
| 161 | + kw["password"] = dsn.password |
| 162 | + else: |
| 163 | + kw["host"] = dsn.host |
| 164 | + kw["port"] = dsn.port |
| 165 | + kw["user"] = dsn.user |
| 166 | + if dsn.password: |
| 167 | + kw["password"] = dsn.password |
| 168 | + |
| 169 | + kw = {k: v for k, v in kw.items() if v is not None} |
| 170 | + |
| 171 | + if issubclass(cls, ThreadedDatabase): |
| 172 | + return cls(thread_count=thread_count, **kw) |
| 173 | + |
| 174 | + return cls(**kw) |
| 175 | + |
| 176 | + |
| 177 | + def connect_with_dict(self, d, thread_count): |
| 178 | + d = dict(d) |
| 179 | + driver = d.pop("driver") |
| 180 | + try: |
| 181 | + matcher = self.match_uri_path[driver] |
| 182 | + except KeyError: |
| 183 | + raise NotImplementedError(f"Driver {driver} currently not supported") |
| 184 | + |
| 185 | + cls = matcher.database_cls |
| 186 | + if issubclass(cls, ThreadedDatabase): |
| 187 | + return cls(thread_count=thread_count, **d) |
| 188 | + |
| 189 | + return cls(**d) |
| 190 | + |
| 191 | + |
| 192 | + def __call__(self, db_conf: Union[str, dict], thread_count: Optional[int] = 1) -> Database: |
| 193 | + """Connect to a database using the given database configuration. |
| 194 | +
|
| 195 | + Configuration can be given either as a URI string, or as a dict of {option: value}. |
| 196 | +
|
| 197 | + The dictionary configuration uses the same keys as the TOML 'database' definition given with --conf. |
| 198 | +
|
| 199 | + thread_count determines the max number of worker threads per database, |
| 200 | + if relevant. None means no limit. |
| 201 | +
|
| 202 | + Parameters: |
| 203 | + db_conf (str | dict): The configuration for the database to connect. URI or dict. |
| 204 | + thread_count (int, optional): Size of the threadpool. Ignored by cloud databases. (default: 1) |
| 205 | +
|
| 206 | + Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. |
| 207 | +
|
| 208 | + Supported drivers: |
| 209 | + - postgresql |
| 210 | + - mysql |
| 211 | + - oracle |
| 212 | + - snowflake |
| 213 | + - bigquery |
| 214 | + - redshift |
| 215 | + - presto |
| 216 | + - databricks |
| 217 | + - trino |
| 218 | + - clickhouse |
| 219 | + - vertica |
| 220 | +
|
| 221 | + Example: |
| 222 | + >>> connect("mysql://localhost/db") |
| 223 | + <data_diff.databases.mysql.MySQL object at 0x0000025DB45F4190> |
| 224 | + >>> connect({"driver": "mysql", "host": "localhost", "database": "db"}) |
| 225 | + <data_diff.databases.mysql.MySQL object at 0x0000025DB3F94820> |
| 226 | + """ |
| 227 | + if isinstance(db_conf, str): |
| 228 | + return self.connect_to_uri(db_conf, thread_count) |
| 229 | + elif isinstance(db_conf, dict): |
| 230 | + return self.connect_with_dict(db_conf, thread_count) |
| 231 | + raise TypeError(f"db configuration must be a URI string or a dictionary. Instead got '{db_conf}'.") |
0 commit comments