# SPDX-License-Identifier: GPL-3.0-or-later
"""
Command line tool for testing SPARQL queries against an endpoint.
"""
import argparse
import contextlib
import os
import subprocess
import sys
import urllib.request
from http import HTTPStatus
from pathlib import Path
from urllib.error import HTTPError
from tqdm.auto import tqdm
from scribe_data.wikidata.check_query.query import QueryExecutionException, QueryFile
from scribe_data.wikidata.check_query.sparql import execute, sparql_context
# Exit statuses returned by main(); EXIT_SUCCESS is also the value the `git`
# return code is compared against in changed_queries().
EXIT_SUCCESS = 0
EXIT_FAILURE = 1
# Directory name that all_queries() searches for in this file's resolved path.
PROJECT_ROOT = "Scribe-Data"
[docs]
def ping(url: str, timeout: int) -> bool:
    """
    Check whether a request to a URL is answered with HTTP status 200.

    Parameters
    ----------
    url : str
        The URL to request.

    timeout : int
        The maximum number of seconds to wait for a reply.

    Returns
    -------
    bool
        True when the response status is 200 (OK). False for any other
        status, or when the request raises; in the latter case the
        exception type and message are written to stderr.
    """
    try:
        with urllib.request.urlopen(url, timeout=timeout) as reply:
            status = reply.getcode()

    # HTTPError is a subclass of Exception, so one broad handler covers both.
    except Exception as exc:
        print(f"{type(exc).__name__}: {exc!s}", file=sys.stderr)
        return False

    return status == HTTPStatus.OK
[docs]
def all_queries() -> list[QueryFile]:
    """
    Collect every '.sparql' file in, and below, the project root directory.

    The project root is found by locating the first path component named
    PROJECT_ROOT in the resolved path of this file.

    Returns
    -------
    list[QueryFile]
        One QueryFile per '.sparql' file found, in os.walk order.

    Raises
    ------
    ValueError
        If no component of this file's path is named PROJECT_ROOT.
    """
    components = Path(__file__).resolve().parts
    root_pos = components.index(PROJECT_ROOT)
    project_dir = str(Path(*components[: root_pos + 1]))

    return [
        QueryFile(candidate)
        for folder, _, names in os.walk(project_dir)
        for name in names
        if (candidate := Path(folder, name)).suffix == ".sparql"
    ]
[docs]
def changed_queries() -> list[QueryFile] | None:
    """
    Find all the SPARQL queries that have changed.

    Runs ``git status --short`` in the current working directory and keeps
    the '.sparql' entries that still exist on disk. Includes new queries,
    also those inside directories that are entirely untracked.

    Returns
    -------
    list[QueryFile] | None
        List of changed/new SPARQL queries, or None if the git command
        exits with a non-zero status (its stderr is echoed to stderr).
    """
    result = subprocess.run(
        (
            "git",
            "status",
            "--short",
            # Without this git reports a new directory as a single "dir/"
            # entry, which would hide the new '.sparql' files inside it.
            "--untracked-files=all",
        ),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        check=False,
    )

    if result.returncode != EXIT_SUCCESS:
        print(f"ERROR: {result.stderr}", file=sys.stderr)
        return None

    queries: list[QueryFile] = []
    for line in result.stdout.splitlines():
        # Short format: two status columns, one space, then the path.
        status, entry = line[:2], line[3:].strip()
        if not entry:
            continue

        # Renames/copies are listed as "ORIG_PATH -> PATH"; only PATH is current.
        if ("R" in status or "C" in status) and " -> " in entry:
            entry = entry.rpartition(" -> ")[2]

        # git wraps paths containing whitespace or special characters in double
        # quotes. Backslash escapes inside the quotes are not decoded here, so
        # such paths fail the is_file() check below and are skipped.
        if len(entry) >= 2 and entry[0] == entry[-1] == '"':
            entry = entry[1:-1]

        fpath = Path(entry).resolve()

        # Deleted files are still listed by git but can no longer be run.
        if fpath.suffix == ".sparql" and fpath.is_file():
            queries.append(QueryFile(fpath))

    return queries
[docs]
def check_sparql_file(fpath: str) -> Path:
    """
    Validate that a command line value names an existing '.sparql' file.

    Parameters
    ----------
    fpath : str
        The file path to validate.

    Returns
    -------
    Path
        The validated file path.

    Raises
    ------
    argparse.ArgumentTypeError
        If the path is not an existing file, or its suffix is not '.sparql'.
    """
    candidate = Path(fpath)

    # Existence is checked first, so a missing file is reported as such even
    # when its extension is wrong as well.
    if candidate.is_file():
        if candidate.suffix == ".sparql":
            return candidate

        raise argparse.ArgumentTypeError(
            f"{candidate} does not have a '.sparql' extension"
        )

    raise argparse.ArgumentTypeError(f"Not a valid file path: {candidate}")
[docs]
def check_positive_int(value: str, err_msg: str) -> int:
    """
    Convert 'value' to an integer and require it to be 1 or greater.

    Parameters
    ----------
    value : str
        The text to convert and validate.

    err_msg : str
        Message of the error raised when validation fails.

    Returns
    -------
    int
        The validated number.

    Raises
    ------
    argparse.ArgumentTypeError
        If 'value' cannot be parsed by int(), or the parsed number is below 1.
    """
    try:
        parsed = int(value)

    except ValueError:
        # Unparseable input is reported the same way as an out-of-range number.
        parsed = None

    if parsed is None or parsed < 1:
        raise argparse.ArgumentTypeError(err_msg)

    return parsed
[docs]
def check_limit(limit: str) -> int:
    """
    Validate the 'limit' command line argument.

    Parameters
    ----------
    limit : str
        The LIMIT value as given on the command line.

    Returns
    -------
    int
        The validated LIMIT.

    Raises
    ------
    argparse.ArgumentTypeError
        If 'limit' is not an integer of value 1 or greater.
    """
    failure_message = "LIMIT must be an integer of value 1 or greater."
    return check_positive_int(limit, failure_message)
[docs]
def check_timeout(timeout: str) -> int:
    """
    Validate the 'timeout' command line argument.

    Parameters
    ----------
    timeout : str
        The timeout value as given on the command line.

    Returns
    -------
    int
        The validated timeout.

    Raises
    ------
    argparse.ArgumentTypeError
        If 'timeout' is not an integer of value 1 or greater.
    """
    failure_message = "timeout must be an integer of value 1 or greater."
    return check_positive_int(timeout, failure_message)
[docs]
def main(argv: list[str] | None = None) -> int:
    """
    Parse the command line, then ping the endpoint or run the selected queries.

    Exactly one of --changed, --all, --file or --ping must be given.

    Parameters
    ----------
    argv : Optional[list[str]], default=None
        If set to None then argparse will use sys.argv as the arguments.

    Returns
    -------
    int
        The exit status - 0 - success; any other value - failure.
        Failure is returned when the ping fails, when the changed queries
        cannot be determined, or when at least one query fails.
    """
    cli = argparse.ArgumentParser(
        description=f"run SPARQL queries from the '{PROJECT_ROOT}' project",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # The four modes of operation; argparse enforces that exactly one is chosen.
    group = cli.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "-c",
        "--changed",
        action="store_true",
        help="run only changed/new SPARQL queries",
    )
    group.add_argument(
        "-a", "--all", action="store_true", help="run all SPARQL queries"
    )
    group.add_argument(
        "-f",
        "--file",
        help="path to a file containing a valid SPARQL query",
        type=check_sparql_file,
    )
    group.add_argument(
        "-p",
        "--ping",
        action="store_true",
        default=False,
        help="check responsiveness of endpoint",
    )

    cli.add_argument(
        "--timeout",
        type=check_timeout,
        default=10,
        help="maximum number of seconds to wait for a response from the endpoint when 'pinging'",
    )
    cli.add_argument(
        "-e",
        "--endpoint",
        type=str,
        default="https://query.wikidata.org/sparql",
        help="URL of the SPARQL endpoint",
    )
    cli.add_argument(
        "-l",
        "--limit",
        type=check_limit,
        default=5,
        help="the maximum number of results a query should return",
    )
    cli.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        default=False,
        help="increase output verbosity",
    )

    args = cli.parse_args(argv)
    endpoint = args.endpoint

    if args.ping:
        if ping(endpoint, args.timeout):
            print(f"Success: pinged '{endpoint}'")
            return EXIT_SUCCESS

        print(
            f"FAILURE: unable to contact '{endpoint}'. Network problems? "
            "Malformed URL? Increase timeout?",
            file=sys.stderr,
        )
        return EXIT_FAILURE

    queries = None
    if args.all:
        queries = all_queries()

    elif args.changed:
        queries = changed_queries()

    elif args.file:
        queries = [QueryFile(args.file)]

    else:
        # Unreachable while the option group is required. Raised explicitly
        # because an `assert` statement is removed when Python runs with -O.
        raise AssertionError("Unknown option")

    # changed_queries() returns None when the git command fails.
    if queries is None:
        return EXIT_FAILURE

    context = sparql_context(endpoint)
    failures = []
    successes = []

    for query in tqdm(queries, position=0):
        try:
            results = execute(query, args.limit, context)
            successes.append((query, results))

        except QueryExecutionException as err:
            failures.append(err)

    success_report(successes, display=args.verbose)
    error_report(failures)

    print("\nSummary")
    print(
        f"\tQueries run: {len(queries)}, passed: {len(successes)}, failed: {len(failures)}\n"
    )

    return EXIT_FAILURE if failures else EXIT_SUCCESS
[docs]
def error_report(failures: list[QueryExecutionException]) -> None:
    """
    Write the failed queries to stderr.

    Nothing is written when 'failures' is empty.

    Parameters
    ----------
    failures : list[QueryExecutionException]
        Failed queries; each one is printed on its own line.
    """
    if failures:
        noun = "queries" if len(failures) != 1 else "query"
        print(f"\nFollowing {noun} failed:\n", file=sys.stderr)

        # One failure per line, the same output a print() per item produces.
        print(*failures, sep="\n", file=sys.stderr)
[docs]
def success_report(successes: list[tuple[QueryFile, dict]], display: bool) -> None:
    """
    Write the successful queries and their results to stdout.

    Nothing is written when 'display' is False or 'successes' is empty.

    Parameters
    ----------
    successes : list[tuple[QueryFile, dict]]
        Successful queries, each paired with the results it returned.

    display : bool
        Whether there should be an output or not.
    """
    if not display or not successes:
        return

    noun = "queries" if len(successes) != 1 else "query"
    print(f"\nFollowing {noun} ran successfully:\n")

    for entry in successes:
        query, results = entry
        print(f"{query.path} returned: {results}")
if __name__ == "__main__":
    # Run the CLI and hand its return value to the interpreter as exit status.
    exit_status = main()
    sys.exit(exit_status)