1.1.7. ML-Model-Provider Node¶

Orchestrator (Back-end) Orchestrator (Back-end) ML Model Metadata Node ML Model Metadata Node CO2 footprint CO2 footprint HW Constraints Node Carbontracker Node Carbontracker Node HW Constraints HW Constraints HW Resource HW Resource ML Model ML Model User input data User input data ML Model ML Model HW Resource HW Resource ML Metadata ML Metadata Baseline forOptimization Application-levelRequirements Node User input data User input data User input data User input data App Requirements App Requirements CO2 footprint CO2 footprint Front-end Front-end User input data User input data Output data Output data User User Model Provider Node ML Solution Provider ML Optimization HW Provider Node FPGA Selector... PIM Results

The ML-Model-Provider Node is responsible for selecting a machine learning model suited to the user's problem and providing it (model name and ONNX model path) to the rest of the pipeline.

1.1.7.1. Inputs and Outputs¶

The following table summarizes the inputs and outputs of the ML-Model-Provider Node:

1.1.7.2. Node Template¶

The following is the Python API provided for the ML-Model-Provider Node. The user is meant to implement the functionality of the node within task_callback(), and to implement the response to configuration requests from the orchestrator inside configuration_callback().

# Copyright 2023 SustainML Consortium
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SustainML ML Model Provider Node Implementation."""

import asyncio
import json
import os
import signal
import sys
import threading
import time

# Directory that contains this file
HERE = os.path.dirname(__file__)
# Absolute path of the "sustainml-wp2" directory that sits next to HERE
WP2_ROOT = os.path.abspath(os.path.join(HERE, os.pardir, "sustainml-wp2"))

# Give WP2_ROOT top priority on the import path, adding it only once
if sys.path.count(WP2_ROOT) == 0:
    sys.path.insert(0, WP2_ROOT)

import hw_provider_fpga

from rdftool.ModelONNXCodebase import model
from neo4j import GraphDatabase
from rdftool.rdfCode import load_graph, get_models_for_problem
from rag.rag_backend import answer_question
from os.path import isdir, dirname, abspath, join
from os import listdir
from sustainml_py.nodes.MLModelNode import MLModelNode
from fastmcp import Client


# Absolute path of the Hugging Face MCP server script ("~" is expanded to the
# current user's home directory)
_MCP_SERVER_SCRIPT_HOME_RELATIVE = (
    "~/SustainML/SustainML_ws/src/sustainml_lib/sustainml_modules/sustainml_modules/sustainml-wp1/hf_mcp_server.py"
)
MCP_SERVER_SCRIPT = os.path.abspath(os.path.expanduser(_MCP_SERVER_SCRIPT_HOME_RELATIVE))

# Lazily initialised MCP state shared by the helpers below:
#   _mcp_loop       - asyncio event loop run by the background thread
#   _mcp_thread     - thread that runs _mcp_loop
#   _mcp_client     - connected FastMCP client (None until first use)
#   _mcp_client_ctx - Client object whose async context is kept open
_mcp_loop = _mcp_thread = _mcp_client = _mcp_client_ctx = None


# Thread target: create a new asyncio event loop, publish it in the module-level
# _mcp_loop, make it the current loop of the calling thread and run it until it
# is stopped. This call blocks for as long as the loop runs.
def _start_asyncio_loop():
    global _mcp_loop
    loop = asyncio.new_event_loop()
    _mcp_loop = loop
    asyncio.set_event_loop(loop)
    loop.run_forever()


# Start the background event loop as soon as the module is imported.
# The thread is a daemon thread, so it does not prevent interpreter exit.
_mcp_thread = threading.Thread(target=_start_asyncio_loop, daemon=True)
_mcp_thread.start()


# Create (once) and return a persistent FastMCP client connected to the HF MCP server.
# The server is described as a command that runs MCP_SERVER_SCRIPT with the
# current Python interpreter in unbuffered mode ("-u").
# Outputs: the cached client; later calls return the same object.
async def _ensure_mcp_client():
    global _mcp_client, _mcp_client_ctx

    if _mcp_client is None:
        hf_server = {
            "command": sys.executable,
            "args": ["-u", MCP_SERVER_SCRIPT],
        }
        # The async context is entered by hand and deliberately left open;
        # _mcp_client_ctx keeps the object so it can be exited at shutdown
        _mcp_client_ctx = Client({"mcpServers": {"hf": hf_server}})
        _mcp_client = await _mcp_client_ctx.__aenter__()

    return _mcp_client


# Calls a Hugging Face MCP tool from synchronous code and returns its decoded JSON payload.
# The coroutine is executed on the background event loop (_mcp_loop) while the
# calling thread blocks waiting for the result.
# Inputs: tool_name (tool name without the "hf_" prefix), args (tool arguments),
#         timeout (seconds to wait for the result; defaults to 120, i.e. 2 minutes)
# Outputs: the object obtained by JSON-decoding the text of the first content item
# Raises: RuntimeError if the background loop is not available or the tool returned
#         no content; a timeout error if the call does not finish in time; any
#         exception raised by the MCP client or by the JSON decoding
def _mcp_call(tool_name: str, args: dict, timeout: float = 120) -> dict:
    loop = _mcp_loop
    if loop is None:
        # Checked before the coroutine object is created so that nothing is
        # left un-awaited when the loop has not been published yet
        raise RuntimeError("MCP event loop is not available")

    async def _run():
        c = await _ensure_mcp_client()
        res = await c.call_tool(f"hf_{tool_name}", args)
        if not res.content:
            raise RuntimeError(f"MCP tool 'hf_{tool_name}' returned no content")
        return json.loads(res.content[0].text)

    # Submit coroutine to background loop
    fut = asyncio.run_coroutine_threadsafe(_run(), loop)
    try:
        return fut.result(timeout=timeout)
    finally:
        # On timeout (or any interruption of the wait) the coroutine would keep
        # running on the background loop; cancel it so it does not linger
        if not fut.done():
            fut.cancel()


# Neo4j config/driver for local checks (used by _model_has_goal).
# Each setting may be overridden with the environment variable of the same name;
# when a variable is not set, the previous hard-coded value is used.
# NOTE: the default password is a development placeholder; set NEO4J_PASSWORD
# in any shared deployment instead of relying on it.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "12345678")
neo4j_driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))


# Return the absolute path of the directory containing the U-Net ONNX models.
# A non-empty SUSTAINML_UNET_DIR environment variable takes precedence; otherwise
# the "vendor/sustain_ml_predictor/unet_models" directory located next to the
# hw_provider_fpga module is used.
def _unet_models_dir():
    override = os.environ.get("SUSTAINML_UNET_DIR")
    if not override:
        package_dir = os.path.dirname(hw_provider_fpga.__file__)
        vendored = os.path.join(package_dir, "vendor", "sustain_ml_predictor", "unet_models")
        return os.path.abspath(vendored)

    return os.path.abspath(override)


# Lists the U-Net .onnx models found in the directory returned by _unet_models_dir().
# Inputs: basename_only (True: names without the ".onnx" extension;
#         False: full paths inside the models directory)
# Outputs: sorted list of strings; an empty list (after printing a warning) when
#          the directory does not exist
def list_unet_onnx_models(basename_only=True):
    models_dir = _unet_models_dir()
    if not os.path.isdir(models_dir):
        print(f"[WARN] U-Net directory not found: {models_dir}")
        return []

    onnx_files = [entry for entry in os.listdir(models_dir) if entry.endswith(".onnx")]
    if basename_only:
        return sorted(os.path.splitext(entry)[0] for entry in onnx_files)
    return sorted(os.path.join(models_dir, entry) for entry in onnx_files)


# Checks in Neo4j whether a model node is linked to the given goal/problem.
# Inputs: neo4j_driver (driver used to open a session), model_name (exact model
#         name), goal (problem name, compared case-insensitively)
# Outputs: True when at least one matching HAS_PROBLEM relationship is counted
def _model_has_goal(neo4j_driver, model_name: str, goal: str) -> bool:
    query = """
    MATCH (m:Model {name: $model})-[:HAS_PROBLEM]->(p:Problem)
    WHERE toLower(p.name) = toLower($goal)
    RETURN COUNT(*) AS cnt
    """
    with neo4j_driver.session() as session:
        record = session.run(query, model=model_name, goal=goal).single()
        if not record:
            return False
        return bool(record["cnt"] > 0)

# Whether to go on spinning or interrupt.
# Set to True by run() right before the node starts spinning and reset to
# False by signal_handler().
running = False


# Load the list of unsupported models.
# Inputs: file_path (path of a UTF-8 text file with one model name per line)
# Outputs: list of the non-blank lines, stripped and lower-cased; an empty list
#          (after printing a warning) when the file cannot be read or decoded
def load_unsupported_models(file_path):
    try:
        # Explicit encoding: without it the decoding depends on the platform locale
        with open(file_path, 'r', encoding="utf-8") as f:
            return [line.strip().lower() for line in f if line.strip()]
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError: undecodable content
        # (UnicodeDecodeError) or an invalid path string. The list is optional,
        # so the failure is reported and an empty list is used.
        print(f"[WARN] Could not load unsupported list: {e}")
        return []


# Lower-cased names of unsupported models, loaded once at import time from the
# "unsupported_models.txt" file located next to this module.
# os.path.join on the absolute module directory is used instead of string
# concatenation: when dirname(__file__) is empty (module run by bare file name)
# the concatenation produced "/unsupported_models.txt", i.e. a file in the
# filesystem root.
unsupported_models = load_unsupported_models(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "unsupported_models.txt")
)


# Close the FastMCP client context and reset MCP globals.
# Best effort: failures are reported with a warning and never raised, because
# this runs while the process is shutting down. Waits at most 5 seconds.
def _shutdown_mcp():
    async def _close():
        global _mcp_client_ctx, _mcp_client
        if _mcp_client_ctx is not None:
            try:
                await _mcp_client_ctx.__aexit__(None, None, None)
            except Exception as e:
                print(f"[WARN] Error while closing the MCP client: {e!r}")
        _mcp_client_ctx = None
        _mcp_client = None

    loop = _mcp_loop
    # The client is only ever created by coroutines running on this loop, so a
    # loop that is missing or not running has nothing to close; returning here
    # also avoids waiting 5 seconds for a coroutine that would never be executed
    if loop is None or not loop.is_running():
        return

    try:
        asyncio.run_coroutine_threadsafe(_close(), loop).result(timeout=5)
    except Exception as e:
        print(f"[WARN] MCP shutdown did not complete cleanly: {e!r}")

# Signal handler (installed for SIGINT in the __main__ block).
# Announces the exit, asks MLModelNode to terminate, clears the module-level
# `running` flag and finally closes the MCP client.
# Inputs: sig, frame (standard signal handler arguments, not used)
def signal_handler(sig, frame):
    global running
    print("\nExiting")
    MLModelNode.terminate()
    running = False
    _shutdown_mcp()


# Maximum number of candidates requested from the recommender before giving up
_MAX_CANDIDATE_ATTEMPTS = 10


# Decode the extra_data payload of a message into a dict.
# Inputs: extra_data_bytes (iterable of byte values, possibly empty)
# Outputs: the decoded JSON object, or {} when the payload is empty, is not
#          valid JSON, or does not decode to a JSON object
def _decode_extra_data(extra_data_bytes):
    if not extra_data_bytes:
        return {}
    try:
        # This node UTF-8 encodes the extra_data it publishes; the incoming
        # payload is assumed to use the same encoding (decoding byte by byte
        # with chr() garbles every non-ASCII character). The sequence is
        # materialised first so bytes() always receives a plain list of ints.
        text = bytes(list(extra_data_bytes)).decode("utf-8", errors="replace")
        decoded = json.loads(text)
    except (TypeError, ValueError):
        # json.JSONDecodeError is a ValueError
        print("[WARN] In model_provider node extra_data JSON is not valid.")
        return {}
    if not isinstance(decoded, dict):
        print("[WARN] In model_provider node extra_data JSON is not an object.")
        return {}
    return decoded


# Resolve the ONNX path of a manually selected model.
# Resolution order: an existing ".onnx" file path given by the user, then a file
# with that name (with or without ".onnx") in the U-Net models directory, and
# finally the graph name -> ONNX path resolution done by model().
def _resolve_selected_model_path(chosen_model):
    if isinstance(chosen_model, str) and chosen_model.endswith(".onnx") and os.path.isfile(chosen_model):
        return chosen_model

    unet_dir = _unet_models_dir()
    candidates = (
        os.path.join(unet_dir, chosen_model),            # name may already include .onnx
        os.path.join(unet_dir, chosen_model + ".onnx"),
    )
    for candidate in candidates:
        if os.path.isfile(candidate):
            return candidate

    return model(chosen_model)


# Fill the output sample with the selected model, its ONNX path and the list of
# unsupported models (JSON in extra_data, UTF-8 encoded).
def _publish_model(ml_model, model_name, onnx_path):
    ml_model.model(model_name)
    ml_model.model_path(onnx_path)
    extra_data = {"unsupported_models": unsupported_models}
    ml_model.extra_data(json.dumps(extra_data).encode("utf-8"))


# User Callback implementation
# Inputs: ml_model_metadata, app_requirements, hw_constraints, ml_model_baseline, hw_baseline, carbonfootprint_baseline
# Outputs: node_status, ml_model
#
# Keys read from the JSON carried in ml_model_metadata.extra_data():
#   "model_selected"            - model chosen by the user; skips automatic selection
#   "model_restrains"           - model names that must not be proposed again
#   "problem_short_description" - free text passed to the recommender
#                                 (treated as empty when absent)
# On any failure the output is filled with model "NO_MODEL", path "N/A" and an
# error description in extra_data; no exception leaves the callback.
def task_callback(ml_model_metadata,
                app_requirements,
                hw_constraints,
                ml_model_baseline,
                hw_baseline,
                carbonfootprint_baseline,
                node_status,
                ml_model):

    print(f"Received Task: {ml_model_metadata.task_id().problem_id()},{ml_model_metadata.task_id().iteration_id()}")

    try:
        extra_data_dict = _decode_extra_data(ml_model_metadata.extra_data())

        # If a model was manually selected, skip automatic selection
        if "model_selected" in extra_data_dict:
            chosen_model = extra_data_dict["model_selected"]
            print("Model already selected: ", chosen_model)
            if chosen_model:
                # The path is resolved first so the output stays untouched if
                # the resolution fails
                onnx_path = _resolve_selected_model_path(chosen_model)
                _publish_model(ml_model, chosen_model, onnx_path)
                return

        # A missing description (or missing extra_data) no longer aborts the
        # selection: the recommender is then queried with the goal alone
        problem_short_description = extra_data_dict.get("problem_short_description", "")

        goal = ml_model_metadata.ml_model_metadata()[0]  # Goal selected by metadata node
        print(f"Problem short description: {problem_short_description}")
        print(f"Selected goal (metadata): {goal}")

        # Build strictly goal-scoped allowed list (names only)
        goal_models = get_models_for_problem(goal)   # [(model_name, downloads), ...]
        allowed_names = [name for (name, _) in goal_models]
        if not allowed_names:
            raise Exception("No candidates in graph for the selected goal")

        # Track models to avoid repeats across outputs
        try:
            restrained_models = set(extra_data_dict.get("model_restrains") or ())
        except TypeError:
            # Malformed restriction list: ignore it rather than fail the task
            restrained_models = set()

        # Try a bounded number of candidates, skipping misfits transparently
        chosen_model = None
        for _ in range(_MAX_CANDIDATE_ATTEMPTS):
            remaining = [n for n in allowed_names if n not in restrained_models]
            if not remaining:
                break

            candidate = answer_question(
                f"Task {goal} with problem description: {problem_short_description}?",
                allowed_models=remaining
            )

            if not candidate or candidate.strip().lower() == "none":
                # Mark and try again
                if candidate:
                    restrained_models.add(candidate)
                continue

            # Final safety: ensure candidate really belongs to goal
            if not _model_has_goal(neo4j_driver, candidate, goal):
                restrained_models.add(candidate)
                continue

            chosen_model = candidate
            break

        if not chosen_model:
            raise Exception("No suitable model after screening candidates")
        print(f"ML Model chosen: {chosen_model}")

        # Generate model code and keywords
        onnx_path = model(chosen_model)     # TODO - Further development needed
        _publish_model(ml_model, chosen_model, onnx_path)

    except Exception as e:
        print(f"[WARN] No suitable model found for task {ml_model_metadata.task_id()}: {e}")
        ml_model.model("NO_MODEL")
        ml_model.model_path("N/A")
        error_message = "No suitable model found for the given problem."
        error_info = {"error_code": "NO_MODEL", "error": error_message}
        encoded_error = json.dumps(error_info).encode("utf-8")
        ml_model.extra_data(encoded_error)


# Number of results requested from the HF search when no valid limit is given
_DEFAULT_HF_SEARCH_LIMIT = 20


# Fill the service response: success flag, error code (0 on success, 1 otherwise)
# and the JSON-encoded payload, in that order.
def _send_configuration_reply(res, ok, payload):
    res.success(ok)
    res.err_code(0 if ok else 1)
    res.configuration(json.dumps(payload))


# Parse an "hf_search, DESCRIPTION, LIMIT" request.
# Outputs: (description, limit). The limit is split off from the RIGHT, so the
# description may contain commas. When the last comma-separated part is not an
# integer it is kept as part of the description and the default limit is used.
def _parse_hf_search_request(raw):
    rest = raw[len("hf_search"):].lstrip()
    if rest.startswith(","):
        rest = rest[1:].lstrip()

    description = rest.strip()
    limit = _DEFAULT_HF_SEARCH_LIMIT
    if "," in rest:
        head, tail = rest.rsplit(",", 1)
        tail = tail.strip()
        if not tail:
            # Trailing comma without a limit
            description = head.strip()
        else:
            try:
                limit = int(tail)
            except ValueError:
                pass  # not a number: the tail belongs to the description
            else:
                description = head.strip()
    return description, limit


# Parse a "model_from_goal, GOAL, HW, FAMILY" request.
# Outputs: (goal, hw, family); missing arguments are returned as "".
# The request sent by the buggy frontend, which repeats the command name
# ("model_from_goal, model_from_goal, GOAL, HW, FAMILY"), is accepted as well.
def _parse_model_from_goal_request(raw):
    _, _, args_str = raw.partition(",")        # everything after first comma
    parts = [p.strip() for p in args_str.split(",")]
    if parts and parts[0].lower() == "model_from_goal":
        parts = parts[1:]
    parts += [""] * (3 - len(parts))           # normalize arity
    return parts[0], parts[1], parts[2]


# User Configuration Callback implementation
# Inputs: req
# Outputs: res
#
# Supported requests (matched case-insensitively on the prefix):
#   "hf_search, DESCRIPTION, LIMIT"
#       metadata-only search through the HF MCP server; replies with the
#       server answer, or with {"models": [], "error": ...} on failure
#   "model_from_goal, GOAL, HW, FAMILY"
#       replies with {"models": "name1, name2, ..."}; the U-Net models are listed
#       when GOAL is "U_NET_MODELS" or when FAMILY is "cnns" and HW contains
#       "fpga", otherwise the models linked to GOAL in the graph
# Any other request is answered with success=False, err_code=1.
def configuration_callback(req, res):

    raw = (req.configuration() or "").strip()
    res.node_id(req.node_id())
    res.transaction_id(req.transaction_id())
    command = raw.lower()

    # HF search (metadata-only browsing; no evaluation)
    if command.startswith("hf_search"):
        description, limit = _parse_hf_search_request(raw)
        try:
            resp = _mcp_call("search_models", {"description": description, "limit": limit})
            _send_configuration_reply(res, True, resp)
        except Exception as e:
            _send_configuration_reply(res, False, {"models": [], "error": str(e)})
        return

    # Only handle the listing endpoint(s)
    if command.startswith("model_from_goal"):
        goal, hw, family = _parse_model_from_goal_request(raw)
        is_cnn = family.lower() == "cnns"
        is_fpga = "fpga" in hw.lower()

        # U-Net fast path: allow sentinel goals like U_NET_MODELS or any goal when (FPGA+CNNs)
        if goal.upper() == "U_NET_MODELS" or (is_cnn and is_fpga):
            try:
                # Same directory resolution as the one used when a selected
                # model is resolved (honours SUSTAINML_UNET_DIR); already sorted
                names = list_unet_onnx_models()
                _send_configuration_reply(res, True, {"models": ", ".join(names)})
            except Exception as e:
                print(f"[ML_MODEL_PROVIDER] U-NET listing error: {e}")
                _send_configuration_reply(res, False, {"models": ""})
            return

        # Default path: fetch by real goal from Neo4j
        try:
            models = get_models_for_problem(goal)  # [(name, downloads)]
            names = sorted(m for (m, _) in models)
            # An empty listing is reported as a failure
            _send_configuration_reply(res, bool(names), {"models": ", ".join(names)})
        except Exception as e:
            print(f"[ML_MODEL_PROVIDER] Neo4j listing error: {e}")
            _send_configuration_reply(res, False, {"models": ""})
        return

    # Unsupported
    msg = f"Unsupported configuration request: {raw}"
    print(f"[ML_MODEL_PROVIDER] {msg}")
    _send_configuration_reply(res, False, {"models": ""})


# Main workflow routine.
# Waits up to 5 seconds for the graph to become available (polling load_graph()
# every 0.1 seconds), then creates the node with the user callbacks, sets the
# module-level `running` flag and spins the node.
# Raises: SystemExit(1) when the graph is not available in time. In the
#         __main__ block this function is a thread target, so the exception
#         ends that worker thread.
def run():
    # Monotonic clock: the timeout is not affected by system clock adjustments
    deadline = time.monotonic() + 5
    loaded = False
    while time.monotonic() < deadline:
        if load_graph():
            loaded = True
            break
        time.sleep(0.1)
    if not loaded:
        print("[Error][ml_model_provider] Graph not available")
        # sys.exit instead of the exit() helper, which is injected by the site
        # module and is not guaranteed to exist
        sys.exit(1)
    node = MLModelNode(callback=task_callback, service_callback=configuration_callback)
    global running
    running = True
    node.spin()


# Call main in program execution
if __name__ == '__main__':
    signal.signal(signal.SIGINT, signal_handler)

    # Python does not process signals asynchronously if the main thread is
    # blocked (spin()), so run the user workflow in another thread
    runner = threading.Thread(target=run)
    runner.start()

    # Wait for the worker with short timed joins so the main thread regularly
    # returns to the interpreter and the SIGINT handler can run. The worker's
    # liveness is checked instead of the `running` flag: the flag is still
    # False while run() is waiting for the graph, which made the previous
    # `while running` loop exit immediately.
    while runner.is_alive():
        runner.join(timeout=1)