Commit 94d7367

Use SparkSession instead of SQLContext
- also for logging-related methods
1 parent 113ab9b commit 94d7367

File tree: cc_index_word_count.py · sparkcc.py · wat_extract_links.py

3 files changed: +47 -50 lines changed


cc_index_word_count.py (+7 -6)

@@ -16,18 +16,19 @@ class CCIndexWordCountJob(WordCountJob, CCIndexWarcSparkJob):
     records_parsing_failed = None
     records_non_html = None
 
-    def init_accumulators(self, sc):
-        super(CCIndexWordCountJob, self).init_accumulators(sc)
+    def init_accumulators(self, session):
+        super(CCIndexWordCountJob, self).init_accumulators(session)
 
+        sc = session.sparkContext
         self.records_parsing_failed = sc.accumulator(0)
         self.records_non_html = sc.accumulator(0)
 
-    def log_accumulators(self, sc):
-        super(CCIndexWordCountJob, self).log_accumulators(sc)
+    def log_accumulators(self, session):
+        super(CCIndexWordCountJob, self).log_accumulators(session)
 
-        self.log_accumulator(sc, self.records_parsing_failed,
+        self.log_accumulator(session, self.records_parsing_failed,
                              'records failed to parse = {}')
-        self.log_accumulator(sc, self.records_non_html,
+        self.log_accumulator(session, self.records_non_html,
                              'records not HTML = {}')
 
     @staticmethod
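With this change, job subclasses receive the SparkSession and fetch the SparkContext themselves when they need to create accumulators. Below is a minimal sketch of the resulting convention for a derived job; the MyCountJob class, its counter, and the import path are hypothetical and not part of this commit:

    from cc_index_word_count import CCIndexWordCountJob  # hypothetical import path

    class MyCountJob(CCIndexWordCountJob):
        records_custom = None

        def init_accumulators(self, session):
            super(MyCountJob, self).init_accumulators(session)
            # Accumulators are still created on the SparkContext,
            # now obtained from the session passed in by run().
            sc = session.sparkContext
            self.records_custom = sc.accumulator(0)

        def log_accumulators(self, session):
            super(MyCountJob, self).log_accumulators(session)
            self.log_accumulator(session, self.records_custom,
                                 'custom records = {}')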

sparkcc.py (+25 -30)

@@ -129,19 +129,22 @@ def init_logging(self, level=None, session=None):
         session.sparkContext.setLogLevel(level)
 
 
-    def init_accumulators(self, sc):
+    def init_accumulators(self, session):
         """Register and initialize counters (aka. accumulators).
         Derived classes may use this method to add their own
-        accumulators but must call super().init_accumulators(sc)
+        accumulators but must call super().init_accumulators(session)
         to also initialize counters from base classes."""
+        sc = session.sparkContext
         self.records_processed = sc.accumulator(0)
         self.warc_input_processed = sc.accumulator(0)
         self.warc_input_failed = sc.accumulator(0)
 
-    def get_logger(self, spark_context=None):
-        """Get logger from SparkContext or (if None) from logging module"""
-        if spark_context:
-            return spark_context._jvm.org.apache.log4j.LogManager \
+    def get_logger(self, session=None):
+        """Get logger from SparkSession or (if None) from logging module"""
+        if not session:
+            session = SparkSession.getActiveSession()
+        if session:
+            return session._jvm.org.apache.log4j.LogManager \
                 .getLogger(self.name)
         return logging.getLogger(self.name)
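get_logger now locates a session on its own via SparkSession.getActiveSession(), available since PySpark 3.0, so call sites no longer have to thread a SparkContext through. A small sketch of the fallback behavior, assuming a driver-side caller (our reading of the diff, not a test from this repository):

    from pyspark.sql import SparkSession

    session = SparkSession.builder.appName('demo').getOrCreate()

    # On the driver the freshly created session is the active one,
    # so get_logger() finds it without an explicit argument and
    # returns the JVM-side log4j logger.
    assert SparkSession.getActiveSession() is session

    # Where no session is active (e.g. on executors), getActiveSession()
    # returns None and get_logger() falls back to Python's logging module.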

@@ -156,7 +159,7 @@ def run(self):
         session = builder.getOrCreate()
 
         self.init_logging(self.args.log_level, session)
-        self.init_accumulators(session.sparkContext)
+        self.init_accumulators(session)
 
         self.run_job(session)
 
@@ -165,27 +168,19 @@ def run(self):
 
         session.stop()
 
-    def log_accumulator(self, sc, acc, descr):
+    def log_accumulator(self, session, acc, descr):
         """Log single counter/accumulator"""
-        self.get_logger(sc).info(descr.format(acc.value))
+        self.get_logger(session).info(descr.format(acc.value))
 
-    def log_accumulators(self, sc):
+    def log_accumulators(self, session):
         """Log counters/accumulators, see `init_accumulators`."""
-        self.log_accumulator(sc, self.warc_input_processed,
+        self.log_accumulator(session, self.warc_input_processed,
                              'WARC/WAT/WET input files processed = {}')
-        self.log_accumulator(sc, self.warc_input_failed,
+        self.log_accumulator(session, self.warc_input_failed,
                              'WARC/WAT/WET input files failed = {}')
-        self.log_accumulator(sc, self.records_processed,
+        self.log_accumulator(session, self.records_processed,
                              'WARC/WAT/WET records processed = {}')
 
-    def log_aggregator(self, sc, agg, descr):
-        """Deprecated, use log_accumulator."""
-        self.log_accumulator(sc, agg, descr)
-
-    def log_aggregators(self, sc):
-        """Deprecated, use log_accumulators."""
-        self.log_accumulators(sc)
-
     @staticmethod
     def reduce_by_key_func(a, b):
         return a + b
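The deprecated log_aggregator/log_aggregators shims are removed outright rather than migrated to the new signature, so downstream jobs still calling them must switch to the accumulator methods. A hypothetical before/after for such a caller:

    # before (removed in this commit):
    #     self.log_aggregators(sc)
    # after:
    #     self.log_accumulators(session)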
@@ -205,7 +200,7 @@ def run_job(self, session):
             .options(**self.get_output_options()) \
             .saveAsTable(self.args.output)
 
-        self.log_accumulators(session.sparkContext)
+        self.log_accumulators(session)
 
     def process_warcs(self, _id, iterator):
         s3pattern = re.compile('^s3://([^/]+)/(.+)')
@@ -342,19 +337,19 @@ def add_arguments(self, parser):
     def load_table(self, session, table_path, table_name):
         parquet_reader = session.read.format('parquet')
         if self.args.table_schema is not None:
-            self.get_logger(session.sparkContext).info(
+            self.get_logger(session).info(
                 "Reading table schema from {}".format(self.args.table_schema))
             with open(self.args.table_schema, 'r') as s:
                 schema = StructType.fromJson(json.loads(s.read()))
             parquet_reader = parquet_reader.schema(schema)
         df = parquet_reader.load(table_path)
         df.createOrReplaceTempView(table_name)
-        self.get_logger(session.sparkContext).info(
+        self.get_logger(session).info(
             "Schema of table {}:\n{}".format(table_name, df.schema))
 
     def execute_query(self, session, query):
         sqldf = session.sql(query)
-        self.get_logger(session.sparkContext).info("Executing query: {}".format(query))
+        self.get_logger(session).info("Executing query: {}".format(query))
         sqldf.explain()
         return sqldf
 
@@ -364,11 +359,11 @@ def load_dataframe(self, session, partitions=-1):
         sqldf.persist()
 
         num_rows = sqldf.count()
-        self.get_logger(session.sparkContext).info(
+        self.get_logger(session).info(
             "Number of records/rows matched by query: {}".format(num_rows))
 
         if partitions > 0:
-            self.get_logger(session.sparkContext).info(
+            self.get_logger(session).info(
                 "Repartitioning data to {} partitions".format(partitions))
             sqldf = sqldf.repartition(partitions)
             sqldf.persist()
@@ -384,7 +379,7 @@ def run_job(self, session):
             .options(**self.get_output_options()) \
             .saveAsTable(self.args.output)
 
-        self.log_accumulators(session.sparkContext)
+        self.log_accumulators(session)
 
 
 class CCIndexWarcSparkJob(CCIndexSparkJob):
@@ -450,7 +445,7 @@ def load_dataframe(self, session, partitions=-1):
         sqldf = reader.load(self.args.input)
 
         if partitions > 0:
-            self.get_logger(sc).info(
+            self.get_logger(session).info(
                 "Repartitioning data to {} partitions".format(partitions))
             sqldf = sqldf.repartition(partitions)
 
@@ -523,4 +518,4 @@ def run_job(self, session):
             .options(**self.get_output_options()) \
             .saveAsTable(self.args.output)
 
-        self.log_accumulators(session.sparkContext)
+        self.log_accumulators(session)

wat_extract_links.py (+15 -14)

@@ -139,7 +139,7 @@ def process_redirect(self, record, stream, http_status_line):
         try:
             redir_to = redir_to.decode('utf-8')
         except UnicodeError as e:
-            self.get_logger().warn(
+            self.get_logger().warning(
                 'URL with unknown encoding: {} - {}'.format(
                     redir_to, e))
             return
@@ -242,9 +242,10 @@ def get_links(self, url, record):
                 url, e))
             self.records_failed.add(1)
 
-    def init_accumulators(self, sc):
-        super(ExtractLinksJob, self).init_accumulators(sc)
+    def init_accumulators(self, session):
+        super(ExtractLinksJob, self).init_accumulators(session)
 
+        sc = session.sparkContext
         self.records_failed = sc.accumulator(0)
         self.records_non_html = sc.accumulator(0)
         self.records_response = sc.accumulator(0)
@@ -254,24 +255,24 @@ def init_accumulators(self, sc):
         self.records_response_robotstxt = sc.accumulator(0)
         self.link_count = sc.accumulator(0)
 
-    def log_accumulators(self, sc):
-        super(ExtractLinksJob, self).log_accumulators(sc)
+    def log_accumulators(self, session):
+        super(ExtractLinksJob, self).log_accumulators(session)
 
-        self.log_accumulator(sc, self.records_response,
+        self.log_accumulator(session, self.records_response,
                              'response records = {}')
-        self.log_accumulator(sc, self.records_failed,
+        self.log_accumulator(session, self.records_failed,
                              'records failed to process = {}')
-        self.log_accumulator(sc, self.records_non_html,
+        self.log_accumulator(session, self.records_non_html,
                              'records not HTML = {}')
-        self.log_accumulator(sc, self.records_response_wat,
+        self.log_accumulator(session, self.records_response_wat,
                              'response records WAT = {}')
-        self.log_accumulator(sc, self.records_response_warc,
+        self.log_accumulator(session, self.records_response_warc,
                              'response records WARC = {}')
-        self.log_accumulator(sc, self.records_response_redirect,
+        self.log_accumulator(session, self.records_response_redirect,
                              'response records redirects = {}')
-        self.log_accumulator(sc, self.records_response_robotstxt,
+        self.log_accumulator(session, self.records_response_robotstxt,
                              'response records robots.txt = {}')
-        self.log_accumulator(sc, self.link_count,
+        self.log_accumulator(session, self.link_count,
                              'non-unique link pairs = {}')
 
     def run_job(self, session):
@@ -480,7 +481,7 @@ def process_robotstxt(self, record, stream, _http_status_line):
                     if thost and src_host and src_host != thost:
                         yield src_host, thost
                 except UnicodeError as e:
-                    self.get_logger().warn(
+                    self.get_logger().warning(
                         'URL with unknown encoding: {} - {}'.format(
                             sitemap, e))
                 line = stream.readline()
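The warn-to-warning rename matters because these call sites invoke get_logger() without a session: when no SparkSession is active, the method returns a Python logging.Logger, where warn() is only a deprecated alias of warning(). A minimal illustration of that fallback path; the logger name and URL are hypothetical:

    import logging

    # What get_logger() returns when no SparkSession is active:
    logger = logging.getLogger('ExtractLinksJob')
    logger.warning('URL with unknown encoding: {} - {}'.format(
        'http://example.com/%ff', 'invalid continuation byte'))
    # logger.warn(...) still works but emits a DeprecationWarning
    # on recent Python versions.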
