@@ -129,19 +129,22 @@ def init_logging(self, level=None, session=None):
             session.sparkContext.setLogLevel(level)


-    def init_accumulators(self, sc):
+    def init_accumulators(self, session):
         """Register and initialize counters (aka. accumulators).
         Derived classes may use this method to add their own
-        accumulators but must call super().init_accumulators(sc)
+        accumulators but must call super().init_accumulators(session)
         to also initialize counters from base classes."""
+        sc = session.sparkContext
         self.records_processed = sc.accumulator(0)
         self.warc_input_processed = sc.accumulator(0)
         self.warc_input_failed = sc.accumulator(0)

-    def get_logger(self, spark_context=None):
-        """Get logger from SparkContext or (if None) from logging module"""
-        if spark_context:
-            return spark_context._jvm.org.apache.log4j.LogManager \
+    def get_logger(self, session=None):
+        """Get logger from SparkSession or (if None) from logging module"""
+        if not session:
+            session = SparkSession.getActiveSession()
+        if session:
+            return session._jvm.org.apache.log4j.LogManager \
                 .getLogger(self.name)
         return logging.getLogger(self.name)

@@ -156,7 +159,7 @@ def run(self):
         session = builder.getOrCreate()

         self.init_logging(self.args.log_level, session)
-        self.init_accumulators(session.sparkContext)
+        self.init_accumulators(session)

         self.run_job(session)

@@ -165,27 +168,19 @@ def run(self):

         session.stop()

-    def log_accumulator(self, sc, acc, descr):
+    def log_accumulator(self, session, acc, descr):
         """Log single counter/accumulator"""
-        self.get_logger(sc).info(descr.format(acc.value))
+        self.get_logger(session).info(descr.format(acc.value))

-    def log_accumulators(self, sc):
+    def log_accumulators(self, session):
         """Log counters/accumulators, see `init_accumulators`."""
-        self.log_accumulator(sc, self.warc_input_processed,
+        self.log_accumulator(session, self.warc_input_processed,
                              'WARC/WAT/WET input files processed = {}')
-        self.log_accumulator(sc, self.warc_input_failed,
+        self.log_accumulator(session, self.warc_input_failed,
                              'WARC/WAT/WET input files failed = {}')
-        self.log_accumulator(sc, self.records_processed,
+        self.log_accumulator(session, self.records_processed,
                              'WARC/WAT/WET records processed = {}')

-    def log_aggregator(self, sc, agg, descr):
-        """Deprecated, use log_accumulator."""
-        self.log_accumulator(sc, agg, descr)
-
-    def log_aggregators(self, sc):
-        """Deprecated, use log_accumulators."""
-        self.log_accumulators(sc)
-
     @staticmethod
     def reduce_by_key_func(a, b):
         return a + b
@@ -205,7 +200,7 @@ def run_job(self, session):
             .options(**self.get_output_options()) \
             .saveAsTable(self.args.output)

-        self.log_accumulators(session.sparkContext)
+        self.log_accumulators(session)

     def process_warcs(self, _id, iterator):
         s3pattern = re.compile('^s3://([^/]+)/(.+)')
@@ -342,19 +337,19 @@ def add_arguments(self, parser):
     def load_table(self, session, table_path, table_name):
         parquet_reader = session.read.format('parquet')
         if self.args.table_schema is not None:
-            self.get_logger(session.sparkContext).info(
+            self.get_logger(session).info(
                 "Reading table schema from {}".format(self.args.table_schema))
             with open(self.args.table_schema, 'r') as s:
                 schema = StructType.fromJson(json.loads(s.read()))
             parquet_reader = parquet_reader.schema(schema)
         df = parquet_reader.load(table_path)
         df.createOrReplaceTempView(table_name)
-        self.get_logger(session.sparkContext).info(
+        self.get_logger(session).info(
             "Schema of table {}:\n{}".format(table_name, df.schema))

     def execute_query(self, session, query):
         sqldf = session.sql(query)
-        self.get_logger(session.sparkContext).info("Executing query: {}".format(query))
+        self.get_logger(session).info("Executing query: {}".format(query))
         sqldf.explain()
         return sqldf

@@ -364,11 +359,11 @@ def load_dataframe(self, session, partitions=-1):
         sqldf.persist()

         num_rows = sqldf.count()
-        self.get_logger(session.sparkContext).info(
+        self.get_logger(session).info(
             "Number of records/rows matched by query: {}".format(num_rows))

         if partitions > 0:
-            self.get_logger(session.sparkContext).info(
+            self.get_logger(session).info(
                 "Repartitioning data to {} partitions".format(partitions))
             sqldf = sqldf.repartition(partitions)
             sqldf.persist()
@@ -384,7 +379,7 @@ def run_job(self, session):
             .options(**self.get_output_options()) \
             .saveAsTable(self.args.output)

-        self.log_accumulators(session.sparkContext)
+        self.log_accumulators(session)


 class CCIndexWarcSparkJob(CCIndexSparkJob):
@@ -450,7 +445,7 @@ def load_dataframe(self, session, partitions=-1):
         sqldf = reader.load(self.args.input)

         if partitions > 0:
-            self.get_logger(sc).info(
+            self.get_logger(session).info(
                 "Repartitioning data to {} partitions".format(partitions))
             sqldf = sqldf.repartition(partitions)

@@ -523,4 +518,4 @@ def run_job(self, session):
             .options(**self.get_output_options()) \
             .saveAsTable(self.args.output)

-        self.log_accumulators(session.sparkContext)
+        self.log_accumulators(session)
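A minimal usage sketch (not part of the diff) of how a derived job would adapt to the session-based signatures introduced above. The `WordCountJob` class, its `words_counted` counter, and the `sparkcc`/`CCSparkJob` import are illustrative assumptions; only the `init_accumulators(session)`, `log_accumulator(session, acc, descr)`, and `log_accumulators(session)` signatures are taken from the change.

```python
# Hypothetical derived job illustrating the new session-based API.
from sparkcc import CCSparkJob  # assumed module/class name for the base job


class WordCountJob(CCSparkJob):
    """Toy job that registers one extra accumulator."""
    name = "WordCountJob"

    def init_accumulators(self, session):
        # Call super() so the base counters (records_processed, ...) are
        # registered, then add an extra counter via the session's SparkContext.
        super().init_accumulators(session)
        self.words_counted = session.sparkContext.accumulator(0)

    def log_accumulators(self, session):
        super().log_accumulators(session)
        self.log_accumulator(session, self.words_counted,
                             'words counted = {}')
```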