]> Nutra Git (v2) - nutratech/cli.git/commitdiff
wip refactor sql() to return headers, and more
authorShane Jaroch <chown_tee@proton.me>
Sat, 2 Mar 2024 15:45:02 +0000 (10:45 -0500)
committerShane Jaroch <chown_tee@proton.me>
Sat, 2 Mar 2024 15:46:00 +0000 (10:46 -0500)
ntclient/__init__.py
ntclient/argparser/funcs.py
ntclient/persistence/sql/__init__.py
ntclient/persistence/sql/nt/__init__.py
ntclient/persistence/sql/nt/funcs.py
ntclient/persistence/sql/usda/__init__.py
ntclient/persistence/sql/usda/funcs.py
ntclient/services/bugs.py
ntclient/services/usda.py

index 5242f3fade5bb37e7c6f3f7ed03dba38fffe6d57..e01d4bc043f46c4da7813a36beeb67dbef489ba2 100644 (file)
@@ -32,7 +32,7 @@ USDA_XZ_SHA256 = "25dba8428ced42d646bec704981d3a95dc7943240254e884aad37d59eee961
 # Global variables
 PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
 NUTRA_HOME = os.getenv("NUTRA_HOME", os.path.join(os.path.expanduser("~"), ".nutra"))
-USDA_DB_NAME = "usda.sqlite"
+USDA_DB_NAME = "usda.sqlite3"
 # NOTE: NT_DB_NAME = "nt.sqlite3" is defined in ntclient.ntsqlite.sql
 
 NTSQLITE_BUILDPATH = os.path.join(PROJECT_ROOT, "ntsqlite", "sql", NT_DB_NAME)
index 2e4389e168e1e214f0794fb17ad6efeca43fe44a..24db797def4e49167393b1807d35819c6b8e78da 100644 (file)
@@ -352,27 +352,46 @@ def bug_simulate(args: argparse.Namespace) -> tuple:
 
 def bugs_list(args: argparse.Namespace) -> tuple:
     """List bug reports that have been saved"""
-    _bugs_list = ntclient.services.bugs.list_bugs()
-    n_bugs_total = len(_bugs_list)
-    n_bugs_unsubmitted = len([x for x in _bugs_list if not bool(x[-1])])
+    rows, headers = ntclient.services.bugs.list_bugs()
+    n_bugs_total = len(rows)
+    n_bugs_unsubmitted = len([x for x in rows if not bool(x[-1])])
 
     print(f"You have: {n_bugs_total} total bugs amassed in your journey.")
     print(f"Of these, {n_bugs_unsubmitted} require submission/reporting.")
     print()
 
-    for bug in _bugs_list:
+    print(rows)
+    # print([[entry for entry in row] for row in rows])
+    # exit(0)
+    table = tabulate(
+        [[entry for entry in row if "\n" not in str(entry)] for row in rows],
+        headers=headers,
+        tablefmt="presto",
+    )
+    print(table)
+    exit(0)
+
+    for bug in rows:
         if not args.show:
             continue
         # Skip submitted bugs by default
         if bool(bug[-1]) and not args.debug:
             continue
         # Print all bug properties (except noisy stacktrace)
-        print(", ".join(str(x) for x in bug if "\n" not in str(x)))
+        bug_line = str()
+        for _col_name, _value in dict(bug).items():
+            print(_col_name)
+            print(_value)
+            if "\n" in str(_value):
+                continue
+            bug_line += str(_value) + ", "
+        # print(", ".join(str(x) for x in bug if "\n" not in str(x)))
+        print()
 
     if n_bugs_unsubmitted > 0:
         print("NOTE: You have bugs awaiting submission.  Please run the report command")
 
-    return 0, _bugs_list
+    return 0, rows
 
 
 # pylint: disable=unused-argument
index 20030d91063e66e317631048d62fe73f6acb11c3..18d28da01e4bc3c8a4fb6bcca006d6015e106b0c 100644 (file)
@@ -2,6 +2,7 @@
 
 import sqlite3
 from collections.abc import Sequence
+from typing import Optional
 
 from ntclient.utils import CLI_CONFIG
 
@@ -10,12 +11,24 @@ from ntclient.utils import CLI_CONFIG
 # ------------------------------------------------
 
 
-def sql_entries(sql_result: sqlite3.Cursor) -> list:
-    """Formats and returns a `sql_result()` for console digestion and output"""
-    # TODO: return object: metadata, command, status, errors, etc?
+def sql_entries(sql_result: sqlite3.Cursor) -> tuple[list, list, int, Optional[int]]:
+    """
+    Formats and returns a `sql_result()` for console digestion and output
+    FIXME: the IDs are not necessarily integers, but are unique.
 
-    rows = sql_result.fetchall()
-    return rows
+    TODO: return object: metadata, command, status, errors, etc?
+    """
+
+    return (
+        # rows
+        sql_result.fetchall(),
+        # headers
+        [x[0] for x in sql_result.description],
+        # row_count
+        sql_result.rowcount,
+        # last_row_id
+        sql_result.lastrowid,
+    )
 
 
 def sql_entries_headers(sql_result: sqlite3.Cursor) -> tuple:
@@ -92,7 +105,7 @@ def _sql(
     query: str,
     db_name: str,
     values: Sequence = (),
-) -> list:
+) -> tuple[list, list, int, Optional[int]]:
     """@param values: tuple | list"""
 
     cur = _prep_query(con, query, db_name, values)
@@ -103,19 +116,3 @@ def _sql(
 
     close_con_and_cur(con, cur)
     return result
-
-
-def _sql_headers(
-    con: sqlite3.Connection,
-    query: str,
-    db_name: str,
-    values: Sequence = (),
-) -> tuple:
-    """@param values: tuple | list"""
-
-    cur = _prep_query(con, query, db_name, values)
-
-    result = sql_entries_headers(cur)
-
-    close_con_and_cur(con, cur)
-    return result
index a920991c2f08b50456ddf7a31f1bcbdb66676d0d..ef19816a4982a39ec8bf3b0f72c079e0fce98344 100644 (file)
@@ -3,6 +3,7 @@
 import os
 import sqlite3
 from collections.abc import Sequence
+from typing import Optional
 
 from ntclient import (
     NT_DB_NAME,
@@ -11,7 +12,7 @@ from ntclient import (
     NUTRA_HOME,
     __db_target_nt__,
 )
-from ntclient.persistence.sql import _sql, _sql_headers, version
+from ntclient.persistence.sql import _sql, version
 from ntclient.utils.exceptions import SqlConnectError, SqlInvalidVersionError
 
 
@@ -80,15 +81,8 @@ def nt_sqlite_connect(version_check: bool = True) -> sqlite3.Connection:
     raise SqlConnectError("ERROR: nt database doesn't exist, please run `nutra init`")
 
 
-def sql(query: str, values: Sequence = ()) -> list:
+def sql(query: str, values: Sequence = ()) -> tuple[list, list, int, Optional[int]]:
     """Executes a SQL command to nt.sqlite3"""
 
     con = nt_sqlite_connect()
     return _sql(con, query, db_name="nt", values=values)
-
-
-def sql_headers(query: str, values: Sequence = ()) -> tuple:
-    """Executes a SQL command to nt.sqlite3"""
-
-    con = nt_sqlite_connect()
-    return _sql_headers(con, query, db_name="nt", values=values)
index 06d2dff8744e2f0c4d591e2727b9da686e794172..d2cd4626c952c6fc8e1a4df9b92579af36ee31bf 100644 (file)
@@ -5,6 +5,8 @@ from ntclient.persistence.sql.nt import sql
 
 def sql_nt_next_index(table: str) -> int:
     """Used for previewing inserts"""
+    # TODO: parameterized queries
     # noinspection SqlResolve
     query = "SELECT MAX(id) as max_id FROM %s;" % table  # nosec: B608
-    return int(sql(query)[0]["max_id"])
+    rows, _, _, _ = sql(query)
+    return int(rows[0]["max_id"])
index e36e976992610e483cbc1d5ba0aaa742c83e4a21..419494eb158867a9e1cea8f08401e5865e0fb1a3 100644 (file)
@@ -5,9 +5,10 @@ import sqlite3
 import tarfile
 import urllib.request
 from collections.abc import Sequence
+from typing import Optional
 
 from ntclient import NUTRA_HOME, USDA_DB_NAME, __db_target_usda__
-from ntclient.persistence.sql import _sql, _sql_headers, version
+from ntclient.persistence.sql import _sql, version
 from ntclient.utils.exceptions import SqlConnectError, SqlInvalidVersionError
 
 
@@ -98,7 +99,9 @@ def usda_ver() -> str:
     return version(con)
 
 
-def sql(query: str, values: Sequence = (), version_check: bool = True) -> list:
+def sql(
+    query: str, values: Sequence = (), version_check: bool = True
+) -> tuple[list, list, int, Optional[int]]:
     """
     Executes a SQL command to usda.sqlite3
 
@@ -114,21 +117,3 @@ def sql(query: str, values: Sequence = (), version_check: bool = True) -> list:
 
     # TODO: support argument: _sql(..., params=params, ...)
     return _sql(con, query, db_name="usda", values=values)
-
-
-def sql_headers(query: str, values: Sequence = (), version_check: bool = True) -> tuple:
-    """
-    Executes a SQL command to usda.sqlite3 [WITH HEADERS]
-
-    @param query: Input SQL query
-    @param values: Union[tuple, list] Leave as empty tuple for no values,
-        e.g. bare query. Populate a tuple for a single insert. And use a list for
-        cur.executemany()
-    @param version_check: Ignore mismatch version, useful for "meta" commands
-    @return: List of selected SQL items
-    """
-
-    con = usda_sqlite_connect(version_check=version_check)
-
-    # TODO: support argument: _sql(..., params=params, ...)
-    return _sql_headers(con, query, db_name="usda", values=values)
index 34422325cac7fe4181203377adc7362351666625..aa08701cef6aadbbd70899c38f96512240c37a3c 100644 (file)
@@ -1,7 +1,7 @@
 """usda.sqlite functions module"""
 
 from ntclient import NUTR_ID_KCAL
-from ntclient.persistence.sql.usda import sql, sql_headers
+from ntclient.persistence.sql.usda import sql
 
 
 ################################################################################
@@ -11,8 +11,8 @@ def sql_fdgrp() -> dict:
     """Shows food groups"""
 
     query = "SELECT * FROM fdgrp;"
-    result = sql(query)
-    return {x[0]: x for x in result}
+    rows, _, _, _ = sql(query)
+    return {x[0]: x for x in rows}
 
 
 def sql_food_details(_food_ids: set = None) -> list:  # type: ignore
@@ -26,22 +26,24 @@ def sql_food_details(_food_ids: set = None) -> list:  # type: ignore
         food_ids = ",".join(str(x) for x in set(_food_ids))
         query = query % food_ids
 
-    return sql(query)
+    rows, _, _, _ = sql(query)
+    return rows
 
 
 def sql_nutrients_overview() -> dict:
     """Shows nutrients overview"""
 
     query = "SELECT * FROM nutrients_overview;"
-    result = sql(query)
-    return {x[0]: x for x in result}
+    rows, _, _, _ = sql(query)
+    return {x[0]: x for x in rows}
 
 
 def sql_nutrients_details() -> tuple:
     """Shows nutrients 'details'"""
 
     query = "SELECT * FROM nutrients_overview;"
-    return sql_headers(query)
+    rows, headers, _, _ = sql(query)
+    return rows, headers
 
 
 def sql_servings(_food_ids: set) -> list:
@@ -61,7 +63,8 @@ WHERE
 """
     # FIXME: support this kind of thing by library code & parameterized queries
     food_ids = ",".join(str(x) for x in set(_food_ids))
-    return sql(query % food_ids)
+    rows, _, _, _ = sql(query % food_ids)
+    return rows
 
 
 def sql_analyze_foods(food_ids: set) -> list:
@@ -79,7 +82,8 @@ WHERE
 """
     # TODO: parameterized queries
     food_ids_concat = ",".join(str(x) for x in set(food_ids))
-    return sql(query % food_ids_concat)
+    rows, _, _, _ = sql(query % food_ids_concat)
+    return rows
 
 
 ################################################################################
@@ -101,8 +105,9 @@ WHERE
 ORDER BY
   food_id;
 """
-
-    return sql(query % (NUTR_ID_KCAL, nutrient_id))
+    # TODO: parameterized queries
+    rows, _, _, _ = sql(query % (NUTR_ID_KCAL, nutrient_id))
+    return rows
 
 
 def sql_sort_foods(nutr_id: int) -> list:
@@ -127,8 +132,9 @@ WHERE
 ORDER BY
   nut_data.nutr_val DESC;
 """
-
-    return sql(query % nutr_id)
+    # TODO: parameterized queries
+    rows, _, _, _ = sql(query % nutr_id)
+    return rows
 
 
 def sql_sort_foods_by_kcal(nutr_id: int) -> list:
@@ -156,5 +162,6 @@ WHERE
 ORDER BY
   (nut_data.nutr_val / kcal.nutr_val) DESC;
 """
-
-    return sql(query % nutr_id)
+    # TODO: parameterized queries
+    rows, _, _, _ = sql(query % nutr_id)
+    return rows
index b91b2f754af8074afac85ce53bfd10771efaddd1..b408bb74e814a5770b26d78c75b6be7b36e8eeb9 100644 (file)
@@ -49,8 +49,8 @@ INSERT INTO bug
                         "version_usda_db_target": __db_target_usda__,
                     }
                 ),
-                # user_details
-                "NOT_IMPLEMENTED",
+                # user_details (TODO: add user details)
+                None,
             ),
         )
     except sqlite3.IntegrityError as exc:
@@ -63,24 +63,24 @@ INSERT INTO bug
             raise
 
 
-def list_bugs() -> list:
-    """List all bugs."""
-    sql_bugs = sql_nt("SELECT * FROM bug")
-    return sql_bugs
+def list_bugs() -> tuple[list, list]:
+    """List all bugs, with headers."""
+    rows, headers, _, _ = sql_nt("SELECT * FROM bug")
+    return rows, headers
 
 
 def submit_bugs() -> int:
     """Submit bug reports to developer, return n_submitted."""
 
     # Gather bugs for submission
-    sql_bugs = sql_nt("SELECT * FROM bug WHERE submitted = 0")
+    rows, _, _, _ = sql_nt("SELECT * FROM bug WHERE submitted = 0")
     api_client = ntclient.services.api.ApiClient()
 
     n_submitted = 0
-    print(f"submitting {len(sql_bugs)} bug reports...")
-    print("_" * len(sql_bugs))
+    print(f"submitting {len(rows)} bug reports...")
+    print("_" * len(rows))
 
-    for bug in sql_bugs:
+    for bug in rows:
         _res = api_client.post_bug(bug)
         if CLI_CONFIG.debug:
             print(_res.json())
index b6face29ff726c5de7db9625900cf23d4d20e58f..d15736dc557dde0904808ce44f17c23040034f6e 100644 (file)
@@ -30,7 +30,7 @@ from ntclient.utils import CLI_CONFIG
 def list_nutrients() -> tuple:
     """Lists out nutrients with basic details"""
 
-    headers, nutrients = sql_nutrients_details()
+    nutrients, headers = sql_nutrients_details()
     # TODO: include in SQL table cache?
     headers.append("avg_rda")
     nutrients = [list(x) for x in nutrients]