def skip_update_job_pod_name(dag):
    """
    :param dag: Airflow DAG
    :return: Dummy operator to skip update pod name
    """
    return DummyOperator(task_id="skip_update_job_pod_name", dag=dag)


def update_pod_name_branch_operator(dag: DAG, job_id: str):
    """branch operator to update pod name."""
    return BranchPythonOperator(
        dag=dag,
        trigger_rule="all_done",
        task_id="update_pod_name",
        python_callable=update_pod_name_func,
        op_kwargs={"job_id": job_id},
    )


def update_pod_name_func(job_id: Optional[str]) -> str:
    """function for update pod name."""
    return "update_job_pod_name" if job_id else "skip_update_pod_name"


def update_job_pod_name(dag: DAG, job_id: str, process_name: str) -> MySqlOperator:
    """
    :param dag: Airflow DAG
    :param job_id: Airflow Job ID
    :param process_name: name of the current running process
    :return: MySqlOperator to update Airflow job ID
    """
    return MySqlOperator(
        task_id="update_job_pod_name",
        mysql_conn_id="semantic-search-airflow-sdk",
        autocommit=True,
        sql=[
            f"""
            INSERT INTO airflow.Pod (job_id, pod_name, task_name)
            SELECT * FROM (SELECT '{job_id}', '{xcom_pull("pod_name")}', '{process_name}') AS temp
            WHERE NOT EXISTS (
                SELECT pod_name FROM airflow.Pod WHERE pod_name = '{{{{ ti.xcom_pull(key="pod_name") }}}}'
            ) LIMIT 1;
            """
        ],
        task_concurrency=1,
        dag=dag,
        trigger_rule="all_done",
    )
def create_k8s_pod_operator_without_volume(dag: DAG,
                                           job_id: int,
                                           ... variable) -> TaskGroup:
    """
    Create task group for k8s operator without volume.
    """
    with TaskGroup(group_id="k8s_pod_operator_without_volume", dag=dag) as eks_without_volume_group:
        emit_pod_name_branch = update_pod_name_branch_operator(dag=dag, job_id=job_id)
        update_pod_name = update_job_pod_name(dag=dag, job_id=job_id, process_name=process_name)
        skip_update_pod_name = skip_update_job_pod_name(dag=dag)
        emit_pod_name_branch >> [update_pod_name, skip_update_pod_name]
    return eks_without_volume_group
I updated the code based on the comment. I am curious how the TaskGroup works with a branch operator; I get this error when I try it:
airflow.exceptions.AirflowException: Branch callable must return valid task_ids. Invalid tasks found: {'update_job_pod_name'}
You can use a BranchPythonOperator, which gets the value and returns the name of the task to run for each condition.
def choose_job_func(job_id):
    if job_id:
        return "update_pod_name_rds"

choose_update_job = BranchPythonOperator(
    task_id="choose_update_job",
    python_callable=choose_job_func,
    op_kwargs={"job_id": "{{ params.job_id }}"},
)
Or, with the TaskFlow API, it would look like this:
@task.branch
def choose_update_job(job_id):
    if job_id:
        return "update_pod_name_rds"
Full DAG example:
with DAG(
    dag_id="test_dag",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
    render_template_as_native_obj=True,
    params={
        "job_id": Param(default=None, type=["null", "string"])
    },
    tags=["test"],
) as dag:

    def update_job_pod_name(job_id: str, process_name: str):
        return MySqlOperator(
            task_id="update_pod_name_rds",
            mysql_conn_id="semantic-search-airflow-sdk",
            autocommit=True,
            sql=[
                f"""
                INSERT INTO airflow.Pod (job_id, pod_name, task_name)
                SELECT * FROM (SELECT '{job_id}', '{xcom_pull("pod_name")}', '{process_name}') AS temp
                WHERE NOT EXISTS (
                    SELECT pod_name FROM airflow.Pod WHERE pod_name = '{{{{ ti.xcom_pull(key="pod_name") }}}}'
                ) LIMIT 1;
                """
            ],
            task_concurrency=1,
            dag=dag,
            trigger_rule="all_done",
        )

    @task.branch
    def choose_update_job(job_id):
        print(job_id)
        if job_id:
            return "update_pod_name_rds"
        return "do_nothing"

    sql_task = update_job_pod_name(
        "{{ params.job_id }}",
        "process_name",
    )
    do_nothing = EmptyOperator(task_id="do_nothing")
    start_dag = EmptyOperator(task_id="start")
    end_dag = EmptyOperator(task_id="end", trigger_rule=TriggerRule.ONE_SUCCESS)

    start_dag >> choose_update_job("{{ params.job_id }}") >> [sql_task, do_nothing] >> end_dag
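As for the original TaskGroup error: tasks created inside a TaskGroup get their task_id prefixed with the group_id, so a branch callable inside the group has to return the prefixed ids. A minimal sketch of that adjustment, reusing the group and task names from the question:

from typing import Optional

def update_pod_name_func(job_id: Optional[str]) -> str:
    # Inside a TaskGroup, downstream task_ids are prefixed with the group_id,
    # e.g. "k8s_pod_operator_without_volume.update_job_pod_name"
    group_id = "k8s_pod_operator_without_volume"
    if job_id:
        return f"{group_id}.update_job_pod_name"
    return f"{group_id}.skip_update_job_pod_name"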
I have a DAG that queries a table, pulling data from it, and that also uses a ShortCircuitOperator to check whether the DAG needs to run, which is likewise based on a BQ table. The issue is that this currently queries the table every time Airflow refreshes. The table is small (each query is less than 1 KB), but I'm concerned this will get more expensive as it's scaled up. Is there a way to only query this table once per DAG run instead?
Here's a code snippet to show what's going on:
client = bigquery.Client()

def create_query_lists(list_type):
    query_job = client.query(
        """
        SELECT filename
        FROM `requests`
        """
    )
    results = query_job.result()
    results_list = []
    for row in results:
        results_list.append(row.filename)
    return results_list

def check_contents():
    if len(create_query_lists()) == 0:
        raise ValueError('Nothing to do')
        return False
    else:
        print("There's stuff to do")
        return True

# Create task to check if data being pulled is empty; if so, fail so other tasks don't run
check_list = ShortCircuitOperator(
    task_id="check_column_not_empty",
    provide_context=True,
    python_callable=check_contents
)

check_list  # do subsequent tasks which use the same function
If I understood your need correctly, you want to execute tasks only if the result of the SQL query is not empty.
In this case you can also use a BranchPythonOperator, for example:
import airflow
from airflow.operators.dummy import DummyOperator
from airflow.operators.python import BranchPythonOperator
from google.cloud import bigquery

client = bigquery.Client()


def create_query_lists():
    query_job = client.query(
        """
        SELECT filename
        FROM `requests`
        """
    )
    results = query_job.result()
    results_list = []
    for row in results:
        results_list.append(row.filename)
    return results_list


def check_contents():
    if len(create_query_lists()) == 0:
        return 'KO'
    else:
        return 'OK'


with airflow.DAG(
        "your_dag",
        schedule_interval=None) as dag:
    branching = BranchPythonOperator(
        task_id='file_exists',
        python_callable=check_contents,
        provide_context=True,
        dag=dag
    )

    ok = DummyOperator(task_id='OK', dag=dag)
    ko = DummyOperator(task_id='KO', dag=dag)
    fake_task = DummyOperator(task_id='fake_task', dag=dag)

    branching >> ok >> fake_task
    branching >> ko
The BranchPythonOperator executes the query; if the result is not empty it returns OK, otherwise KO.
We create two DummyOperators, one for OK and the other for KO (two branches).
Depending on the result, we go to the OK or the KO branch.
The KO branch finishes the DAG without running other tasks.
The OK branch continues the DAG with the tasks that follow (fake_task in my example).
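If you prefer to keep the ShortCircuitOperator from the question, the same query function can drive it, since downstream tasks are skipped whenever the callable returns a falsy value. A minimal sketch reusing create_query_lists from above:

from airflow.operators.python import ShortCircuitOperator

check_list = ShortCircuitOperator(
    task_id="check_column_not_empty",
    # Downstream tasks run only when the query returned at least one filename
    python_callable=lambda: len(create_query_lists()) > 0,
)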
I have a use case wherein I need to extract the value using xcom_pull.
Version of Airflow : 2.3.4
Composer version : 2.1.1
live_fw_num="{{ ti.xcom_pull(dag_id='" + DAG_ID + "',task_ids='get_fw_of_month')[0][0] }}")
The output coming out is 1 (images attached).
The first image shows the value in the XCom tab.
The second image shows the value when it is extracted into the live_fw_num variable.
Code:
today = datetime.date.today()

def function_1(table2, table3, live_fw_num):
    if live_fw_num == '1':  # (tried with an integer value as well)
        <do something>
    else:
        <do something else>

with dag:
    task_1 = PythonOperator(
        task_id='get_fw_of_month',
        python_callable=get_data,
        op_kwargs={'sql': task_1_func(tb1=<some table name>,
                                      curr_dte=today,
                                      )
                   }
    )

    task_3 = PythonOperator(
        task_id='task3',
        python_callable=function_1,
        op_kwargs={'table2': <table name>,
                   'table3': <table name>,
                   'live_fw_num': "{{ ti.xcom_pull(dag_id='" + DAG_ID + "',task_ids='get_fw_of_month')[0][0] }}"
                   }
    )

    task_1 >> task_3
But when I compare this value to a static value with an if-else clause, it goes to the else condition instead of the if condition, even though the output value of live_fw_num is 1.
I tested your DAG and your code, and it worked on my side.
The only difference: I mocked the get_data method used by the first task to return 1 as a string.
import datetime
import logging

import airflow
from airflow.operators.python import PythonOperator

from integration_ocd.config import settings

today = datetime.date.today()


def get_data():
    return '1'


def function_1(table2, table3, live_fw_num):
    if live_fw_num == '1':
        logging.info('############################################<do something>')
    else:
        logging.info('############################################<do something else>')


with airflow.DAG(
        'dag_test_xcom',
        default_args=your_default_dag_args,
        schedule_interval=None) as dag:
    task_1 = PythonOperator(
        task_id='get_fw_of_month',
        python_callable=get_data
    )

    task_3 = PythonOperator(
        task_id='task3',
        python_callable=function_1,
        op_kwargs={
            'table2': 'table1',
            'table3': 'table2',
            'live_fw_num': "{{ ti.xcom_pull(dag_id='dag_test_xcom',task_ids='get_fw_of_month')[0][0] }}"
        }
    )

    task_1 >> task_3
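One hedged guess based on the answer above: if the real get_data returns something other than the literal string '1' (an integer, a Decimal from the database, or a padded string), the comparison fails even though the rendered value displays as 1. A small defensive normalisation of function_1, assuming that is the cause:

import logging

def function_1(table2, table3, live_fw_num):
    # Normalise so that 1, Decimal('1') and ' 1 ' all compare equal to '1'
    if str(live_fw_num).strip() == '1':
        logging.info('<do something>')
    else:
        logging.info('<do something else>')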
Can I pass a return value from DatabricksRunNowOperator using XCom or any other method? I just want to return a "date" after my first Databricks operator has finished and pass it on to the dependent tasks.
For example:
I want to pass the return value of the verification_run to the insert_run and workspace_run. Usually we can use xcom_pull and xcom_push to do this in Python, but I'm not sure how to make the two notebooks talk to each other.
from airflow import DAG
from airflow.providers.databricks.operators.databricks import DatabricksRunNowOperator
from airflow.utils.dates import days_ago

default_args = {
    'owner': 'airflow'
}

with DAG('databricks_dag',
         start_date=days_ago(2),
         schedule_interval=None,
         default_args=default_args
         ) as dag:

    verification_run = DatabricksRunNowOperator(
        task_id='verification_task',
        databricks_conn_id='databricks_default',
        job_id='-----'
    )

    insert_run = DatabricksRunNowOperator(
        task_id='insert_task',
        databricks_conn_id='databricks_default',
        job_id='-----'
    )

    workspace_run = DatabricksRunNowOperator(
        task_id='workspace_task',
        databricks_conn_id='databricks_default',
        job_id='------'
    )

    verification_run >> [insert_run, workspace_run]
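The Databricks-specific part of this question is still open, but for reference, the plain-Python xcom_push/xcom_pull pattern the question mentions looks roughly like this (the DAG and task names here are illustrative, not part of the original DAG):

from datetime import datetime
from airflow import DAG
from airflow.operators.python import PythonOperator

def push_date():
    # The return value of a PythonOperator callable is pushed to XCom automatically
    return "2023-01-01"

def use_date(ti):
    # Pull the value pushed by the upstream task
    date = ti.xcom_pull(task_ids="push_date_task")
    print(f"Got date from upstream: {date}")

with DAG("xcom_pattern_example", start_date=datetime(2022, 1, 1), schedule_interval=None) as dag:
    push_task = PythonOperator(task_id="push_date_task", python_callable=push_date)
    pull_task = PythonOperator(task_id="use_date_task", python_callable=use_date)
    push_task >> pull_task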
I need to execute multiple pipelines in parallel (each pipeline runs the same logic but with different inputs) and want to get the metric counts after executing each pipeline. The problem is with creating the pipeline template: since only one template file is created per pipeline, each new one overrides the old template, and I end up with the template file for the last pipeline alone. Basically, we are doing this to check whether the row count fetched from BigQuery and the rows written to Bigtable are the same for the different data sources.
Note: I am using multithreading to schedule the pipelines in parallel.
def getrow_count(self, pipeline, metric_name):
    if not hasattr(pipeline.result, 'has_job'):
        read_filter = MetricsFilter().with_name(metric_name)
        query_result = pipeline.result.metrics().query(read_filter)
        if query_result['counters']:
            read_counter = query_result['counters'][0]
            print(f"Row count for metric {metric_name} is {read_counter.committed}")
            return read_counter.committed

def run_pipeline(self, query_text, schema, table_type):
    job_name = table_type.replace('_', '') + datetime.datetime.now().strftime("%Y%m%d")
    self.options.view_as(GoogleCloudOptions).job_name = str(job_name)
    self.options.view_as(GoogleCloudOptions).temp_location = 'gs://dataflow_storage_bq_bt/dataflow_test/tmp/' + table_type
    self.options.view_as(GoogleCloudOptions).staging_location = 'gs://dataflow_storage_bq_bt/dataflow_test/tmp/' + table_type
    print(self.options.view_as(GoogleCloudOptions).staging_location)

    with beam.Pipeline(options=self.options) as pipeline:
        data_collection = pipeline | f"Get {table_type} from BigQuery" >> beam.io.Read(ReadFromBigQuery(query=query_text.get(),
                                                                                                        use_standard_sql=True))
        data_collection \
            | f"Get {table_type} list of direct_row's " >> beam.ParDo(CreateRowFn(schema)) \
            | f"Get {table_type} single direct row" >> beam.ParDo(GetRowFn()) \
            | f"Write {table_type} To BT" >> WriteToBigTable(project_id=self.config_data["gcp_config"]["bt_project"],
                                                             instance_id=self.config_data["gcp_config"]["bt_instance_id"],
                                                             table_id=self.config_data["gcp_config"]["bt_table_id"])

    bigquery_count = self.getrow_count(pipeline, 'bigquery_row')
    bigtable_count = self.getrow_count(pipeline, 'Written Row')
    if bigquery_count is None or bigtable_count is None:
        print(f"No daily upload data for {table_type}")
    elif bigquery_count == bigtable_count:
        print(f"All daily upload data for {table_type} moved to bigtable from bigquery")
    else:
        raise ValueError(f"Row count mismatch; check the pipeline for {table_type}")
def get_query_text(self, file_path):
    query_text_read_output = self.get_blob_data(file_path)
    query_text = query_text_read_output.decode('utf-8')
    return query_text

def get_blob_data(self, file_path):
    blob = self.bucket.get_blob(file_path)
    data = blob.download_as_string()
    return data
def run(self):
    self.set_options()
    sql_config = self.config_data["sql_config"]
    querytext1 = self.options.view_as(DailyUploadOptions).test1
    querytext2 = self.options.view_as(DailyUploadOptions).test2
    querytext3 = self.options.view_as(DailyUploadOptions).test3
    querytext4 = self.options.view_as(DailyUploadOptions).test4
    querytext5 = self.options.view_as(DailyUploadOptions).test5
    querytext6 = self.options.view_as(DailyUploadOptions).test6
    querytext7 = self.options.view_as(DailyUploadOptions).test7
    Thread(target=self.run_pipeline,
           args=(querytext1,
                 sql_config['test1_config']['schema'],
                 'feature1')).start()
    time.sleep(100)  # sleep is needed, or the Dataflow job fails because the same job name is picked in the subsequent job
    Thread(target=self.run_pipeline,
           args=(querytext2,
                 sql_config['test2_config']['schema'],
                 'feature2')).start()
    time.sleep(100)
    Thread(target=self.run_pipeline,
           args=(querytext3,
                 sql_config['test3_config']['schema'],
                 'feature3')).start()
    time.sleep(100)
    Thread(target=self.run_pipeline,
           args=(querytext4,
                 sql_config['test4_config']['schema'],
                 'feature4')).start()
    time.sleep(100)
    Thread(target=self.run_pipeline,
           args=(querytext5,
                 sql_config['test5_config']['schema'],
                 'feature5')).start()
    time.sleep(100)
    Thread(target=self.run_pipeline,
           args=(querytext6,
                 sql_config['test6_config']['schema'],
                 'feature6')).start()
    time.sleep(100)
    Thread(target=self.run_pipeline,
           args=(querytext7,
                 sql_config['test7_config']['schema'],
                 'feature7')).start()
class CreateRowFn(beam.DoFn):
    def __init__(self, schema):
        self.schema = schema
        self.bg_count = Metrics.counter('Bigquery', 'bigquery_row')

    def process(self, key):
        self.bg_count.inc()
        direct_rows = []
        data = json.loads(key['data'], strict=False)
        direct_row = row.DirectRow(row_key=data["row_key"])
        for table_type in self.schema:
            for column_family in self.schema[table_type]['columns']:
                for column in self.schema[table_type]['columns'][column_family]:
                    direct_row.set_cell(
                        column_family,
                        column,
                        json.dumps(data.get(column, {})),
                        datetime.datetime.fromtimestamp(0.0))
        direct_rows.append(direct_row)
        return [direct_rows]


class GetRowFn(beam.DoFn):
    def process(self, row_list):
        for row in row_list:
            return [row]
Can you try this? Define the pipeline once, and then use it for all the branches. In addition, don't worry about parallelism; let Dataflow manage it. You will disturb its own thread manager if you manage the threads yourself.
def run_pipeline(self, pipeline, query_text, schema, table_type):
    data_collection = pipeline | f"Get {table_type} from BigQuery" >> beam.io.Read(ReadFromBigQuery(query=query_text.get(), use_standard_sql=True))
    .......

def run(self):
    self.set_options()
    sql_config = self.config_data["sql_config"]
    querytext1 = self.options.view_as(DailyUploadOptions).test1
    querytext2 = self.options.view_as(DailyUploadOptions).test2
    querytext3 = self.options.view_as(DailyUploadOptions).test3
    querytext4 = self.options.view_as(DailyUploadOptions).test4
    querytext5 = self.options.view_as(DailyUploadOptions).test5
    querytext6 = self.options.view_as(DailyUploadOptions).test6
    querytext7 = self.options.view_as(DailyUploadOptions).test7

    job_name = table_type.replace('_', '') + datetime.datetime.now().strftime("%Y%m%d")
    self.options.view_as(GoogleCloudOptions).job_name = str(job_name)
    self.options.view_as(GoogleCloudOptions).temp_location = 'gs://dataflow_storage_bq_bt/dataflow_test/tmp/' + table_type
    self.options.view_as(GoogleCloudOptions).staging_location = 'gs://dataflow_storage_bq_bt/dataflow_test/tmp/' + table_type
    print(self.options.view_as(GoogleCloudOptions).staging_location)

    pipeline = beam.Pipeline(options=self.options)

    self.run_pipeline(pipeline, querytext1,
                      sql_config['test1_config']['schema'],
                      'feature1')
    self.run_pipeline(pipeline, querytext2,
                      sql_config['test2_config']['schema'],
                      'feature2')
    self.run_pipeline(pipeline, querytext3,
                      sql_config['test3_config']['schema'],
                      'feature3')
    self.run_pipeline(pipeline, querytext4,
                      sql_config['test4_config']['schema'],
                      'feature4')
    self.run_pipeline(pipeline, querytext5,
                      sql_config['test5_config']['schema'],
                      'feature5')
    self.run_pipeline(pipeline, querytext6,
                      sql_config['test6_config']['schema'],
                      'feature6')
    self.run_pipeline(pipeline, querytext7,
                      sql_config['test7_config']['schema'],
                      'feature7')

    pipeline.run()
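With a single combined pipeline, the metric checks from the question can then be done once against the PipelineResult after it finishes; a minimal sketch of how the end of run() could look, assuming the same counter names ('bigquery_row', 'Written Row') used above:

from apache_beam.metrics.metric import MetricsFilter

# Run the combined pipeline, wait for it, then query the counters once
result = pipeline.run()
result.wait_until_finish()

for metric_name in ('bigquery_row', 'Written Row'):
    query_result = result.metrics().query(MetricsFilter().with_name(metric_name))
    for counter in query_result['counters']:
        # counter.key.step identifies which branch (feature) the counter belongs to
        print(f"{metric_name} [{counter.key.step}]: {counter.committed}")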
My pipeline has the following simple JSON input
{"mac": "KC:FC:48:AE:F6:94", "status": 8, "datetime": "2015-07-13T21:15:02Z"}
The output should basically go to a BigQuery table with 3 columns (mac, status and datetime) with their corresponding values
My Pipeline looks as follows:
# -*- coding: utf-8 -*-
import os, json, logging, argparse, datetime, apache_beam as beam

from google.cloud import error_reporting

from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.options.pipeline_options import GoogleCloudOptions

GOOGLE_PUBSUB_CHANNEL = 'projects/project-name/topics/topic-name'
GOOGLE_BIGQUERY_TABLE = 'bq-table'
GOOGLE_DATASET_ID = 'bq-dataset'
GOOGLE_PROJECT_ID = 'project-name'


class GoogleBigQuery():
    client_error = error_reporting.Client()

    @staticmethod
    def get_schema_table(schema):
        bigquery_schema = []
        for key in range(len(schema)):
            bigquery_schema.append('{}:{}'.format(schema[key].get('bigquery_field_name'), schema[key].get('bigquery_field_type')))
        return ','.join(bigquery_schema)


fields_contract = (
    {'bigquery_field_name': 'datetime', 'bigquery_field_type': 'STRING'},
    {'bigquery_field_name': 'mac', 'bigquery_field_type': 'STRING'},
    {'bigquery_field_name': 'status', 'bigquery_field_type': 'INTEGER'}
)
def parse_pubsub(line):
    record = json.loads(line)
    logging.info(record)
    return record


class FilterStatus1(beam.DoFn):
    def status_filter_1(self, data):
        for r in data:
            print(r)
            logging.info(r)
            if r["status"] == 1:
                print(r)
                logging.info(r)
                yield r
def run(argv=None):
    parser = argparse.ArgumentParser()
    known_args, pipeline_args = parser.parse_known_args(argv)

    pipeline_parameters = [
        '--runner', 'DirectRunner'
        , '--staging_location', 'gs://bucket/staging'
        , '--temp_location', 'gs://bucket/temp'
        , '--autoscaling_algorithm', 'THROUGHPUT_BASED'  # 'NONE' to disable autoscaling
        , '--num_workers', '1'
        , '--max_num_workers', '2'
        , '--disk_size_gb', '30'
        , '--worker_machine_type', 'n1-standard-1'
    ]

    pipeline_options = PipelineOptions(pipeline_parameters)
    pipeline_options.view_as(StandardOptions).streaming = True
    pipeline_options.view_as(GoogleCloudOptions).job_name = os.path.basename(__file__).split('.')[0].replace('_', '-')
    pipeline_options.view_as(GoogleCloudOptions).project = GOOGLE_PROJECT_ID

    with beam.Pipeline(options=pipeline_options, argv=pipeline_parameters) as p:
        # Read the pubsub topic into a PCollection.
        lines = (
            p
            | 'ReadPubSubMessage' >> beam.io.ReadFromPubSub(GOOGLE_PUBSUB_CHANNEL).with_output_types(bytes)
            | 'Decode UTF-8' >> beam.Map(lambda x: x.decode('utf-8'))
            | 'ParsePubSub' >> beam.Map(parse_pubsub)
        )

        (
            lines
            | 'Filter Status 1' >> beam.ParDo(FilterStatus1())
            | 'WriteToBigQueryStatus1' >> beam.io.WriteToBigQuery(
                GOOGLE_BIGQUERY_TABLE
                , project=GOOGLE_PROJECT_ID
                , dataset=GOOGLE_DATASET_ID
                , schema=GoogleBigQuery.get_schema_table(fields_contract)
                , create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED
                , write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND
                #, write_disposition=beam.io.BigQueryDisposition.WRITE_TRUNCATE
            )
        )

    logging.info('Pipeline finished')
    result = p.run()
    result.wait_until_finish()


if __name__ == '__main__':
    logging.getLogger().setLevel(logging.INFO)
    run()
I'm getting the following error:
RuntimeError: NotImplementedError [while running 'Filter Status 1']
My goal here is to filter on the status column and, when the value is 1, stream the record into BQ.
Thanks in advance for helping me out.
You can try a filtering approach using FlatMap to do such things.
First, define a filtering method:
def FilterStatus1(row):
    if row["status"] == 1:
        yield row
Then you can apply it like:
lines = lines | beam.FlatMap(FilterStatus1) | 'WriteToBigQueryStatus1' ...
Also, try breaking your code up into chunks, or explicitly assigned steps. This many transformations, mappings and filterings happening in a single step usually turns your code into a black box.
Hope it helps. Thanks.
I fixed my code this way. The NotImplementedError came from the DoFn never overriding process() (the filtering logic lived in status_filter_1, which Beam never calls), so defining process() directly resolves it:
class FilterStatus1(beam.DoFn):
    def process(self, data):
        if data["status"] == 1:
            result = [{"datetime": data["datetime"], "mac": data["mac"], "status": data["status"]}]
            logging.info(result)
            return result
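For completeness, the same DoFn can also yield the filtered record instead of returning a single-element list; a minimal equivalent sketch:

class FilterStatus1(beam.DoFn):
    def process(self, data):
        # Emit only records whose status is 1, keeping the three target columns
        if data["status"] == 1:
            yield {"datetime": data["datetime"], "mac": data["mac"], "status": data["status"]}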