1.1.7. ML-Model-Provider Node

Orchestrator (Back-end) Orchestrator (Back-end) ML Model Metadata Node ML Model Metadata Node CO2 footprint CO2 footprint HW Constraints Node Carbontracker Node Carbontracker Node HW Constraints HW Constraints HW Resource HW Resource ML Model ML Model User input data User input data ML Model ML Model HW Resource HW Resource ML Metadata ML Metadata Baseline forOptimization Application-levelRequirements Node User input data User input data User input data User input data App Requirements App Requirements CO2 footprint CO2 footprint Front-end Front-end User input data User input data Output data Output data User User Model Provider Node ML Solution Provider ML Optimization HW Provider Node FPGA Selector... PIM Results

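The ML-Model-Provider Node is responsible for selecting a machine learning model suited to the task described by the ML model metadata, and for providing that model (its name and the path to its ONNX file) to the rest of the pipeline.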

1.1.7.1. Inputs and Outputs

The following table summarizes the inputs and outputs of the ML-Model-Provider Node:

1.1.7.2. Node Template

The following is the Python API provided for the ML-Model-Provider Node. The user is meant to implement the functionality of the node within task_callback(), and to implement the response to configuration requests from the orchestrator inside configuration_callback().

# Copyright 2023 SustainML Consortium
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SustainML ML Model Provider Node Implementation."""

from sustainml_py.nodes.MLModelNode import MLModelNode

# Manage signaling
import os
import signal
import threading
import time
import json

from rdftool.ModelONNXCodebase import model
from rdftool.rdfCode import load_graph, get_models_for_problem, get_models_for_problem_and_tag

from rag.rag_backend import answer_question

# Whether to go on spinning or interrupt.
# Module-level flag: run() sets it to True right before the node starts
# spinning, and signal_handler() sets it back to False on SIGINT.
running = False


# Load the list of unsupported models
def load_unsupported_models(file_path):
    """Read the unsupported-model names stored in ``file_path``.

    The file is read line by line; surrounding whitespace is removed, blank
    lines are skipped and every remaining entry is lower-cased.

    Args:
        file_path: Path of the text file holding one model name per line.

    Returns:
        The list of normalised names, or an empty list (after printing a
        warning) when the file cannot be read for any reason.
    """
    try:
        with open(file_path, 'r') as handle:
            trimmed = (raw.strip() for raw in handle)
            names = [entry.lower() for entry in trimmed if entry]
    except Exception as e:
        # Best effort: a missing or unreadable list must not stop the node.
        print(f"[WARN] Could not load unsupported list: {e}")
        return []
    return names


# Resolve the list next to this module. abspath() guards against a relative
# __file__ whose dirname is "": plain concatenation would then yield
# "/unsupported_models.txt", i.e. a file at the filesystem root.
unsupported_models = load_unsupported_models(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "unsupported_models.txt"))


# Signal handler
def signal_handler(sig, frame):
    """Handle an interrupt: terminate the node and clear the ``running`` flag.

    Args:
        sig: Number of the received signal (unused).
        frame: Stack frame that was interrupted (unused).
    """
    global running

    print("\nExiting")
    MLModelNode.terminate()
    running = False


# User Callback implementation
# Inputs: ml_model_metadata, app_requirements, hw_constraints, ml_model_baseline, hw_baseline, carbonfootprint_baseline
# Outputs: node_status, ml_model
def task_callback(ml_model_metadata,
                  app_requirements,
                  hw_constraints,
                  ml_model_baseline,
                  hw_baseline,
                  carbonfootprint_baseline,
                  node_status,
                  ml_model):
    """Choose an ML model for the received task and publish it through ``ml_model``.

    The optional ``extra_data`` of ``ml_model_metadata`` is expected to be a
    UTF-8 encoded JSON object. The keys read here are:

    * ``model_selected``: a model that has already been chosen; when present
      the RAG query is skipped.
    * ``problem_short_description``: free text used to build the RAG query;
      only required when no model has been pre-selected.

    Other keys (for instance ``type`` or ``model_restrains``) are ignored by
    this implementation.

    On success ``ml_model`` receives the model name, the path returned by
    ``model()`` and an ``extra_data`` JSON payload listing the unsupported
    models. On any failure ``ml_model.model`` and ``ml_model.model_path`` are
    set to ``"Error"`` and ``extra_data`` carries ``{"error": "<message>"}``.

    Args:
        ml_model_metadata: Input carrying the task id, the metadata list and
            the optional ``extra_data`` bytes.
        app_requirements: Input, not used by this implementation.
        hw_constraints: Input, not used by this implementation.
        ml_model_baseline: Input, not used by this implementation.
        hw_baseline: Input, not used by this implementation.
        carbonfootprint_baseline: Input, not used by this implementation.
        node_status: Output, not modified by this implementation.
        ml_model: Output filled in place through its setter methods.
    """
    print(f"Received Task: {ml_model_metadata.task_id().problem_id()},{ml_model_metadata.task_id().iteration_id()}")

    try:
        chosen_model = None
        # Stays None when extra_data is empty or lacks the key, so the check
        # below can report the problem explicitly instead of failing on an
        # unbound local or a bare KeyError.
        problem_short_description = None

        extra_data_bytes = ml_model_metadata.extra_data()
        if extra_data_bytes:
            try:
                # Decode as UTF-8 (the encoding this node also uses for its own
                # extra_data) rather than byte-by-byte, which would garble any
                # non-ASCII character.
                extra_data_dict = json.loads(bytes(extra_data_bytes).decode("utf-8"))
            except (UnicodeDecodeError, json.JSONDecodeError):
                print("[WARN] In model_provider node extra_data JSON is not valid.")
                extra_data_dict = {}

            if not isinstance(extra_data_dict, dict):
                # Valid JSON but not an object (e.g. a list): nothing usable.
                print("[WARN] In model_provider node extra_data JSON is not valid.")
                extra_data_dict = {}

            if "model_selected" in extra_data_dict:
                chosen_model = extra_data_dict["model_selected"]
                print("Model already selected: ", chosen_model)

            problem_short_description = extra_data_dict.get("problem_short_description")

        metadata = ml_model_metadata.ml_model_metadata()[0]

        if chosen_model is None:
            if problem_short_description is None:
                # The description is only needed to build the RAG query.
                raise ValueError("problem_short_description is missing from extra_data")

            print(f"Problem short description: {problem_short_description}")

            # Choose model with the RAG based on the goal selected and the knowledge of the graph.
            chosen_model = answer_question(
                f"Task {metadata} with problem description: {problem_short_description}?"
            )

        print(f"ML Model chosen: {chosen_model}")

        # Generate model code and keywords
        onnx_path = model(chosen_model)     # TODO - Further development needed
        ml_model.model(chosen_model)
        ml_model.model_path(onnx_path)
        # Add unsupported_models information to extra_data in json format
        extra_data = {"unsupported_models": unsupported_models}
        encoded_data = json.dumps(extra_data).encode("utf-8")
        ml_model.extra_data(encoded_data)

    except Exception as e:
        # Top-level boundary of the callback: report the failure through the
        # output sample instead of letting the exception escape into the node.
        print(f"Failed to determine ML model for task {ml_model_metadata.task_id()}: {e}.")
        ml_model.model("Error")
        ml_model.model_path("Error")
        error_message = "Failed to obtain ML model for task: " + str(e)
        error_info = {"error": error_message}
        encoded_error = json.dumps(error_info).encode("utf-8")
        ml_model.extra_data(encoded_error)


# User Configuration Callback implementation
# Inputs: req
# Outputs: res
def configuration_callback(req, res):
    """Answer a configuration request coming from the orchestrator.

    The only supported request is a comma-separated string of the form::

        model_from_goal, <goal>[, <tag>]

    The models found for ``goal`` (optionally restricted by ``tag``) are
    returned in ``res.configuration`` as the JSON object
    ``{"models": "<sorted, comma-separated names>"}``; ``success`` is False
    and ``err_code`` is 1 when no model was found. Any other request, or a
    failure while looking the models up, is answered with
    ``{"error": "<message>"}``, ``success(False)`` and ``err_code(1)``.

    Args:
        req: Request; ``configuration()``, ``node_id()`` and
            ``transaction_id()`` are read from it.
        res: Response, filled in place through its setter methods.
    """
    configuration = req.configuration()

    # Every response echoes the identifiers of the request it answers.
    res.node_id(req.node_id())
    res.transaction_id(req.transaction_id())

    if 'model_from_goal' not in configuration:
        error_msg = f"Unsupported configuration request: {configuration}"
        res.configuration(json.dumps({"error": error_msg}))
        res.success(False)
        res.err_code(1)  # 0: No error || 1: Error
        print(error_msg)
        return

    try:
        # Drop the command keyword up to the first comma. Partitioning, unlike
        # a fixed-length slice, does not depend on the exact spacing that
        # follows the comma.
        _, _, text = configuration.partition(',')
        parts = text.split(',')
        if len(parts) >= 2:
            goal = parts[0].strip()
            tag = parts[1].strip()
            models = get_models_for_problem_and_tag(goal, tag)
        else:
            goal = text.strip()
            models = get_models_for_problem(goal)

        sorted_models = ', '.join(sorted(str(m[0]) for m in models))

        if not sorted_models:
            res.success(False)
            res.err_code(1)  # 0: No error || 1: Error
        else:
            res.success(True)
            res.err_code(0)  # 0: No error || 1: Error

        print(f"Models for {goal}: {sorted_models}")  # debug
        res.configuration(json.dumps(dict(models=sorted_models)))

    except Exception as e:
        error_msg = f"Error getting model from goal from request: {e}"
        print(error_msg)
        # Report the failure in the payload too, as the unsupported-request
        # branch does, so the requester is not left with an empty response.
        res.configuration(json.dumps({"error": error_msg}))
        res.success(False)
        res.err_code(1)  # 0: No error || 1: Error


# Main workflow routine
def run():
    """Load the knowledge graph, create the node and spin it until it stops.

    ``load_graph()`` is retried every 0.1 s for up to 5 s. If it never
    succeeds, an error is printed and ``SystemExit(1)`` is raised. Otherwise
    the node is created with ``task_callback`` and ``configuration_callback``
    and ``spin()`` is called; ``running`` is True for as long as the node is
    spinning.

    Raises:
        SystemExit: When the graph could not be loaded within the timeout.
    """
    global running

    # A monotonic clock keeps the timeout correct even if the wall clock is
    # adjusted while waiting.
    deadline = time.monotonic() + 5
    loaded = False
    while time.monotonic() < deadline:
        if load_graph():
            loaded = True
            break
        time.sleep(0.1)
    if not loaded:
        print("[Error] Graph not available")
        # exit() is a helper injected by the site module for interactive use;
        # raising SystemExit has the same effect without depending on it.
        raise SystemExit(1)

    node = MLModelNode(callback=task_callback, service_callback=configuration_callback)
    running = True
    try:
        node.spin()
    finally:
        # Clear the flag however spin() ends, so nothing keeps waiting on a
        # node that is no longer spinning.
        running = False


# Call main in program execution
if __name__ == '__main__':
    signal.signal(signal.SIGINT, signal_handler)

    # Python does not process signals asynchronously while the main thread is
    # blocked (as it would be inside spin()), so run the user workflow in
    # another thread and keep the main thread free to handle SIGINT.
    runner = threading.Thread(target=run)
    runner.start()

    # Wait on the worker thread itself rather than on the `running` flag: the
    # flag is still False while run() is loading the graph, so polling it here
    # would fall through immediately. The join timeout makes the main thread
    # wake up regularly, which lets pending signal handlers run.
    while runner.is_alive():
        runner.join(timeout=1)