99from dataclasses import dataclass
1010from packaging .version import parse as parse_version
1111from typing import List , Optional , Dict , Tuple , Set
12- from .utils import getLogger
12+ from .utils import dbt_diff_string_template , getLogger
1313from .version import __version__
1414from pathlib import Path
1515
@@ -69,16 +69,16 @@ class DiffVars:
6969 dev_path : List [str ]
7070 prod_path : List [str ]
7171 primary_keys : List [str ]
72- datasource_id : str
7372 connection : Dict [str , str ]
7473 threads : Optional [int ]
7574
7675
7776def dbt_diff (
7877 profiles_dir_override : Optional [str ] = None , project_dir_override : Optional [str ] = None , is_cloud : bool = False
7978) -> None :
79+ diff_threads = []
8080 set_entrypoint_name ("CLI-dbt" )
81- dbt_parser = DbtParser (profiles_dir_override , project_dir_override , is_cloud )
81+ dbt_parser = DbtParser (profiles_dir_override , project_dir_override )
8282 models = dbt_parser .get_models ()
8383 datadiff_variables = dbt_parser .get_datadiff_variables ()
8484 config_prod_database = datadiff_variables .get ("prod_database" )
@@ -89,7 +89,17 @@ def dbt_diff(
8989 custom_schemas = True if custom_schemas is None else custom_schemas
9090 set_dbt_user_id (dbt_parser .dbt_user_id )
9191
92- if not is_cloud :
92+ if is_cloud :
93+ if datasource_id is None :
94+ raise ValueError (
95+ "Datasource ID not found, include it as a dbt variable in the dbt_project.yml. \n vars:\n data_diff:\n datasource_id: 1234"
96+ )
97+ datafold_host , url , api_key = _setup_cloud_diff ()
98+
99+ # exit so the user can set the key
100+ if not api_key :
101+ return
102+ else :
93103 dbt_parser .set_connection ()
94104
95105 if config_prod_database is None :
@@ -98,14 +108,14 @@ def dbt_diff(
98108 )
99109
100110 for model in models :
101- diff_vars = _get_diff_vars (
102- dbt_parser , config_prod_database , config_prod_schema , model , datasource_id , custom_schemas
103- )
104-
105- if is_cloud and len ( diff_vars . primary_keys ) > 0 :
106- _cloud_diff ( diff_vars )
107- elif not is_cloud and len ( diff_vars . primary_keys ) > 0 :
108- _local_diff (diff_vars )
111+ diff_vars = _get_diff_vars (dbt_parser , config_prod_database , config_prod_schema , model , custom_schemas )
112+
113+ if diff_vars . primary_keys :
114+ if is_cloud :
115+ diff_thread = run_as_daemon ( _cloud_diff , diff_vars , datasource_id , datafold_host , url , api_key )
116+ diff_threads . append ( diff_thread )
117+ else :
118+ _local_diff (diff_vars )
109119 else :
110120 rich .print (
111121 "[red]"
@@ -116,6 +126,11 @@ def dbt_diff(
116126 + "Skipped due to unknown primary key. Add uniqueness tests, meta, or tags.\n "
117127 )
118128
129+ # wait for all threads
130+ if diff_threads :
131+ for thread in diff_threads :
132+ thread .join ()
133+
119134 rich .print ("Diffs Complete!" )
120135
121136
@@ -124,7 +139,6 @@ def _get_diff_vars(
124139 config_prod_database : Optional [str ],
125140 config_prod_schema : Optional [str ],
126141 model ,
127- datasource_id : int ,
128142 custom_schemas : bool ,
129143) -> DiffVars :
130144 dev_database = model .database
@@ -149,9 +163,7 @@ def _get_diff_vars(
149163 dev_qualified_list = [dev_database , dev_schema , model .alias ]
150164 prod_qualified_list = [prod_database , prod_schema , model .alias ]
151165
152- return DiffVars (
153- dev_qualified_list , prod_qualified_list , primary_keys , datasource_id , dbt_parser .connection , dbt_parser .threads
154- )
166+ return DiffVars (dev_qualified_list , prod_qualified_list , primary_keys , dbt_parser .connection , dbt_parser .threads )
155167
156168
157169def _local_diff (diff_vars : DiffVars ) -> None :
@@ -221,33 +233,10 @@ def _local_diff(diff_vars: DiffVars) -> None:
221233 )
222234
223235
224- def _cloud_diff (diff_vars : DiffVars ) -> None :
225- datafold_host = os .environ .get ("DATAFOLD_HOST" )
226- if datafold_host is None :
227- datafold_host = "https://app.datafold.com"
228- datafold_host = datafold_host .rstrip ("/" )
229- rich .print (f"Cloud datafold host: { datafold_host } " )
230-
231- api_key = os .environ .get ("DATAFOLD_API_KEY" )
232- if not api_key :
233- rich .print ("[red]API key not found, add it as an environment variable called DATAFOLD_API_KEY." )
234- yes_or_no = Confirm .ask ("Would you like to generate a new API key?" )
235- if yes_or_no :
236- webbrowser .open (f"{ datafold_host } /login?next={ datafold_host } /users/me" )
237- return
238- else :
239- raise ValueError ("Cannot diff because the API key is not provided" )
240-
241- if diff_vars .datasource_id is None :
242- raise ValueError (
243- "Datasource ID not found, include it as a dbt variable in the dbt_project.yml. \n vars:\n data_diff:\n datasource_id: 1234"
244- )
245-
246- url = f"{ datafold_host } /api/v1/datadiffs"
247-
236+ def _cloud_diff (diff_vars : DiffVars , datasource_id : int , datafold_host : str , url : str , api_key : str ) -> None :
248237 payload = {
249- "data_source1_id" : diff_vars . datasource_id ,
250- "data_source2_id" : diff_vars . datasource_id ,
238+ "data_source1_id" : datasource_id ,
239+ "data_source2_id" : datasource_id ,
251240 "table1" : diff_vars .prod_path ,
252241 "table2" : diff_vars .dev_path ,
253242 "pk_columns" : diff_vars .primary_keys ,
@@ -258,27 +247,59 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
258247 "Content-Type" : "application/json" ,
259248 }
260249 if is_tracking_enabled ():
261- event_json = create_start_event_json ({"is_cloud" : True , "datasource_id" : diff_vars . datasource_id })
250+ event_json = create_start_event_json ({"is_cloud" : True , "datasource_id" : datasource_id })
262251 run_as_daemon (send_event_json , event_json )
263252
264253 start = time .monotonic ()
265254 error = None
266255 diff_id = None
267256 try :
268- response = requests . request ( "POST" , url , headers = headers , json = payload , timeout = 30 )
269- response . raise_for_status ()
270- data = response . json ( )
271- diff_id = data [ "id" ]
257+ diff_id = _cloud_submit_diff ( url , payload , headers )
258+ summary_url = f" { url } / { diff_id } /summary_results"
259+ diff_results = _cloud_poll_and_get_summary_results ( summary_url , headers )
260+
272261 diff_url = f"{ datafold_host } /datadiffs/{ diff_id } /overview"
273- rich .print (
274- "[red]"
275- + "." .join (diff_vars .prod_path )
276- + " <> "
277- + "." .join (diff_vars .dev_path )
278- + "[/] \n Diff in progress: \n "
279- + diff_url
280- + "\n "
281- )
262+
263+ rows_added_count = diff_results ["pks" ]["exclusives" ][1 ]
264+ rows_removed_count = diff_results ["pks" ]["exclusives" ][0 ]
265+
266+ rows_updated = diff_results ["values" ]["rows_with_differences" ]
267+ total_rows = diff_results ["values" ]["total_rows" ]
268+ rows_unchanged = int (total_rows ) - int (rows_updated )
269+ diff_percent_list = {
270+ x ["column_name" ]: str (x ["match" ]) + "%"
271+ for x in diff_results ["values" ]["columns_diff_stats" ]
272+ if x ["match" ] != 100.0
273+ }
274+
275+ if any ([rows_added_count , rows_removed_count , rows_updated ]):
276+ diff_output = dbt_diff_string_template (
277+ rows_added_count ,
278+ rows_removed_count ,
279+ rows_updated ,
280+ str (rows_unchanged ),
281+ diff_percent_list ,
282+ "Value Match Percent:" ,
283+ )
284+ rich .print (
285+ "[red]"
286+ + "." .join (diff_vars .prod_path )
287+ + " <> "
288+ + "." .join (diff_vars .dev_path )
289+ + f"[/]\n { diff_url } \n "
290+ + diff_output
291+ + "\n "
292+ )
293+ else :
294+ rich .print (
295+ "[red]"
296+ + "." .join (diff_vars .prod_path )
297+ + " <> "
298+ + "." .join (diff_vars .dev_path )
299+ + f"[/]\n { diff_url } \n "
300+ + "[green]No row differences[/] \n "
301+ )
302+
282303 except BaseException as ex : # Catch KeyboardInterrupt too
283304 error = ex
284305 finally :
@@ -302,15 +323,72 @@ def _cloud_diff(diff_vars: DiffVars) -> None:
302323 send_event_json (event_json )
303324
304325 if error :
305- raise error
326+ logger .error (error )
327+
328+
329+ def _setup_cloud_diff () -> Tuple [str | None ]:
330+ datafold_host = os .environ .get ("DATAFOLD_HOST" )
331+ if datafold_host is None :
332+ datafold_host = "https://app.datafold.com"
333+ datafold_host = datafold_host .rstrip ("/" )
334+ rich .print (f"Cloud datafold host: { datafold_host } \n " )
335+ url = f"{ datafold_host } /api/v1/datadiffs"
336+
337+ api_key = os .environ .get ("DATAFOLD_API_KEY" )
338+ if not api_key :
339+ rich .print ("[red]API key not found, add it as an environment variable called DATAFOLD_API_KEY." )
340+ yes_or_no = Confirm .ask ("Would you like to generate a new API key?" )
341+ if yes_or_no :
342+ webbrowser .open (f"{ datafold_host } /login?next={ datafold_host } /users/me" )
343+ return None , None , None
344+ else :
345+ raise ValueError ("Cannot diff because the API key is not provided" )
346+
347+ return datafold_host , url , api_key
348+
349+
350+ def _cloud_submit_diff (url , payload , headers ) -> str :
351+ response = requests .request ("POST" , url , headers = headers , json = payload , timeout = 30 )
352+ response .raise_for_status ()
353+ response_json = response .json ()
354+ diff_id = str (response_json ["id" ])
355+
356+ if diff_id is None :
357+ raise Exception (f"Api response did not contain a diff_id: { str (response_json )} " )
358+ return diff_id
359+
360+
361+ def _cloud_poll_and_get_summary_results (url , headers ):
362+ summary_results = None
363+ start_time = time .time ()
364+ sleep_interval = 5 # starts at 5 sec
365+ max_sleep_interval = 60
366+ max_wait_time = 300
367+
368+ while not summary_results :
369+ response = requests .request ("GET" , url , headers = headers , timeout = 30 )
370+ response .raise_for_status ()
371+ response_json = response .json ()
372+
373+ if response_json ["status" ] == "success" :
374+ summary_results = response_json
375+ elif response_json ["status" ] == "failed" :
376+ raise Exception (f"Diff failed: { str (response_json )} " )
377+
378+ if time .time () - start_time > max_wait_time :
379+ raise Exception ("Timed out waiting for diff results" )
380+
381+ time .sleep (sleep_interval )
382+ sleep_interval = min (sleep_interval * 2 , max_sleep_interval )
383+
384+ return summary_results
306385
307386
308387class DbtParser :
309- def __init__ (self , profiles_dir_override : str , project_dir_override : str , is_cloud : bool ) -> None :
388+ def __init__ (self , profiles_dir_override : str , project_dir_override : str ) -> None :
310389 self .parse_run_results , self .parse_manifest , self .ProfileRenderer , self .yaml = import_dbt ()
311390 self .profiles_dir = Path (profiles_dir_override or default_profiles_dir ())
312391 self .project_dir = Path (project_dir_override or default_project_dir ())
313- self .is_cloud = is_cloud
314392 self .connection = None
315393 self .project_dict = self .get_project_dict ()
316394 self .manifest_obj = self .get_manifest_obj ()
0 commit comments