]> Nutra Git (v2) - nutratech/cli.git/commitdiff
refactor re-used code into compact method calls
authorShane Jaroch <chown_tee@proton.me>
Sun, 11 Jan 2026 10:03:21 +0000 (05:03 -0500)
committerShane Jaroch <chown_tee@proton.me>
Sun, 11 Jan 2026 10:03:21 +0000 (05:03 -0500)
ntclient/models/__init__.py
ntclient/persistence/sql/__init__.py
ntclient/persistence/sql/nt/__init__.py
ntclient/persistence/sql/usda/__init__.py
ntclient/services/analyze.py
ntclient/services/calculate.py
tests/test_cli.py

index 5280b171326dd253b4ecc4c6e3a66c35f5c07a31..5f2ba54c43f28aef25f08fa20e7ee2f7730504ff 100644 (file)
@@ -7,6 +7,13 @@ Classes, structures for storing, displaying, and editing data.
 """
 import csv
 
+from ntclient import BUFFER_WD
+from ntclient.persistence.sql.usda.funcs import (
+    sql_analyze_foods,
+    sql_nutrients_overview,
+)
+from ntclient.services.analyze import day_format
+from ntclient.services.calculate import calculate_nutrient_totals
 from ntclient.utils import CLI_CONFIG
 
 
@@ -61,12 +68,6 @@ class Recipe:
 
     def print_analysis(self, scale: float = 0, scale_mode: str = "kcal") -> None:
         """Run analysis on a single recipe"""
-        from ntclient import BUFFER_WD
-        from ntclient.persistence.sql.usda.funcs import (
-            sql_analyze_foods,
-            sql_nutrients_overview,
-        )
-        from ntclient.services.analyze import day_format
 
         # Get nutrient overview (RDAs, units, etc.)
         nutrients_rows = sql_nutrients_overview()
@@ -85,20 +86,9 @@ class Recipe:
                 foods_analysis[food_id].append(anl)
 
         # Compute totals
-        nutrient_totals = {}
-        total_weight = 0.0
-        for food_id, grams in self.food_data.items():
-            total_weight += grams
-            if food_id not in foods_analysis:
-                continue
-            for _nutrient in foods_analysis[food_id]:
-                nutr_id = _nutrient[0]
-                nutr_per_100g = _nutrient[1]
-                nutr_val = grams / 100 * nutr_per_100g
-                if nutr_id not in nutrient_totals:
-                    nutrient_totals[nutr_id] = nutr_val
-                else:
-                    nutrient_totals[nutr_id] += nutr_val
+        nutrient_totals, total_weight = calculate_nutrient_totals(
+            self.food_data, foods_analysis
+        )
 
         # Print results using day_format for consistency
         buffer = BUFFER_WD - 4 if BUFFER_WD > 4 else BUFFER_WD
index 8e2e40ffa6ca971789d1ed48440b938adbdc1d0d..285a2eb35d921257b7cb4ae28fda21799a478733 100644 (file)
@@ -1,7 +1,7 @@
 """Main SQL persistence module, shared between USDA and NT databases"""
 
 import sqlite3
-from collections.abc import Sequence
+from collections.abc import Sequence  # pylint: disable=import-error
 
 from ntclient.utils import CLI_CONFIG
 
index 5c711c9e3dc542b0e42f6a2db0c01391ffa127cd..10c1a3023562c2138ad3f51bf2ae90d4265c6bed 100644 (file)
@@ -2,7 +2,7 @@
 
 import os
 import sqlite3
-from collections.abc import Sequence
+from collections.abc import Sequence  # pylint: disable=import-error
 
 from ntclient import (
     NT_DB_NAME,
index 4434a411dd5dcb6af99a5d9b7f261441bb31de7f..c16d8d3dfe1ce65971dd9892616ac303bb530fd0 100644 (file)
@@ -4,7 +4,7 @@ import os
 import sqlite3
 import tarfile
 import urllib.request
-from collections.abc import Sequence
+from collections.abc import Sequence  # pylint: disable=import-error
 
 from ntclient import NUTRA_HOME, USDA_DB_NAME, __db_target_usda__
 from ntclient.persistence.sql import _sql, version
index b09f0c014ec1b175e1a56510776b51fe534e0281..9cdb48c06a7f957cb579e2863b5274e1cee72e83 100644 (file)
@@ -196,22 +196,23 @@ def day_analyze(
     # Compute totals
     nutrients_totals = []
     total_grams_list = []
+    from ntclient.services.calculate import calculate_nutrient_totals
+
     for log in logs:
-        nutrient_totals = OrderedDict()  # NOTE: dict()/{} is NOT ORDERED before 3.6/3.7
-        daily_grams = 0.0
+        # Aggregate duplicates in log if any
+        food_data: OrderedDict[int, float] = OrderedDict()
         for entry in log:
             if entry["id"]:
-                food_id = int(entry["id"])
-                grams = float(entry["grams"])
-                daily_grams += grams
-                for _nutrient2 in foods_analysis[food_id]:
-                    nutr_id = _nutrient2[0]
-                    nutr_per_100g = _nutrient2[1]
-                    nutr_val = grams / 100 * nutr_per_100g
-                    if nutr_id not in nutrient_totals:
-                        nutrient_totals[nutr_id] = nutr_val
-                    else:
-                        nutrient_totals[nutr_id] += nutr_val
+                f_id = int(entry["id"])
+                f_grams = float(entry["grams"])
+                if f_id in food_data:
+                    food_data[f_id] += f_grams
+                else:
+                    food_data[f_id] = f_grams
+
+        nutrient_totals, daily_grams = calculate_nutrient_totals(
+            food_data, foods_analysis
+        )
         nutrients_totals.append(nutrient_totals)
         total_grams_list.append(daily_grams)
 
@@ -239,40 +240,15 @@ def day_format(
 ) -> None:
     """Formats day analysis for printing to console"""
 
-    multiplier = 1.0
-    if scale:
-        if scale_mode == "kcal":
-            current_val = analysis.get(NUTR_ID_KCAL, 0)
-            multiplier = scale / current_val if current_val else 0
-        elif scale_mode == "weight":
-            multiplier = scale / total_weight if total_weight else 0
-        else:
-            # Try to interpret scale_mode as nutrient ID or Name
-            target_id = None
-            # 1. Check if int
-            try:
-                target_id = int(scale_mode)
-            except ValueError:
-                # 2. Check names
-                for n_id, n_data in nutrients.items():
-                    # n_data usually: (id, rda, unit, tag, name, ...)
-                    # Check tag or desc
-                    if scale_mode.lower() in str(n_data[3]).lower():
-                        target_id = n_id
-                        break
-                    if scale_mode.lower() in str(n_data[4]).lower():
-                        target_id = n_id
-                        break
-
-            if target_id and target_id in analysis:
-                current_val = analysis[target_id]
-                multiplier = scale / current_val if current_val else 0
-            else:
-                print(f"WARN: Could not scale by '{scale_mode}', nutrient not found.")
+    from ntclient.services.calculate import calculate_scaling_multiplier
+
+    multiplier = calculate_scaling_multiplier(
+        scale, scale_mode, analysis, nutrients, total_weight
+    )
 
-        # Apply multiplier
-        if multiplier != 1.0:
-            analysis = {k: v * multiplier for k, v in analysis.items()}
+    # Apply multiplier
+    if multiplier != 1.0:
+        analysis = {k: v * multiplier for k, v in analysis.items()}
 
     # Actual values
     kcals = round(analysis.get(NUTR_ID_KCAL, 0))
index 72a3d2c884c56299a5f77db8a278cf81a9e924b3..1b3f312d38f32b95e80bd7535f2d9a9e3b596eb7 100644 (file)
@@ -8,6 +8,8 @@ Created on Tue Aug 11 20:53:14 2020
 """
 import argparse
 import math
+from collections import OrderedDict
+from typing import Mapping
 
 from ntclient.utils import Gender
 
@@ -511,3 +513,82 @@ def lbl_casey_butt(height: float, args: argparse.Namespace) -> tuple:
         # calf
         round(0.9812 * ankle + 0.1250 * height, 2),
     )
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# Nutrient Aggregation
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+def calculate_nutrient_totals(
+    food_data: Mapping[int, float], foods_analysis: Mapping[int, list]
+) -> tuple[OrderedDict, float]:
+    """
+    Common logic to aggregate nutrient data for a list of foods.
+
+    @param food_data: dict of {food_id: grams, ...}
+    @param foods_analysis: dict of {food_id: [(nutr_id, val_per_100g), ...], ...}
+    @return: (nutrient_totals, total_grams)
+    """
+    nutrient_totals = OrderedDict()
+    total_grams = 0.0
+
+    for food_id, grams in food_data.items():
+        total_grams += grams
+        if food_id not in foods_analysis:
+            continue
+        for _nutrient in foods_analysis[food_id]:
+            nutr_id = _nutrient[0]
+            nutr_per_100g = _nutrient[1]
+            nutr_val = grams / 100 * nutr_per_100g
+            if nutr_id not in nutrient_totals:
+                nutrient_totals[nutr_id] = nutr_val
+            else:
+                nutrient_totals[nutr_id] += nutr_val
+
+    return nutrient_totals, total_grams
+
+
+def calculate_scaling_multiplier(
+    scale: float,
+    scale_mode: str,
+    analysis: Mapping,
+    nutrients: Mapping,
+    total_weight: float,
+) -> float:
+    """
+    Determine the multiplier needed to scale the analysis values.
+    """
+    multiplier = 1.0
+    from ntclient import NUTR_ID_KCAL
+
+    if not scale:
+        return multiplier
+
+    if scale_mode == "kcal":
+        current_val = analysis.get(NUTR_ID_KCAL, 0)
+        multiplier = scale / current_val if current_val else 0
+    elif scale_mode == "weight":
+        multiplier = scale / total_weight if total_weight else 0
+    else:
+        # Try to interpret scale_mode as nutrient ID or Name
+        target_id = None
+        # 1. Check if int
+        try:
+            target_id = int(scale_mode)
+        except ValueError:
+            # 2. Check names
+            for n_id, n_data in nutrients.items():
+                # n_data usually: (id, rda, unit, tag, name, ...)
+                if scale_mode.lower() in str(n_data[3]).lower():
+                    target_id = n_id
+                    break
+                if scale_mode.lower() in str(n_data[4]).lower():
+                    target_id = n_id
+                    break
+
+        if target_id and target_id in analysis:
+            current_val = analysis[target_id]
+            multiplier = scale / current_val if current_val else 0
+        else:
+            print(f"WARN: Could not scale by '{scale_mode}', nutrient not found.")
+
+    return multiplier
index 04870c592378102ca6b9b57a01d141245e3a23d2..3a768d4629a4400bd50d05527e66d630efe867ea 100644 (file)
@@ -430,7 +430,8 @@ class TestCli(unittest.TestCase):
             pytest.xfail("PermissionError, are you using Microsoft Windows?")
 
         # mocks input, could also pass `-y` flag or set yes=True
-        usda.input = lambda x: "y"  # pylint: disable=redefined-builtin
+        # pylint: disable=redefined-builtin
+        usda.input = lambda x: "y"
 
         code, successful = init()
         assert code == 0