1.1.7. ML-Model-Provider Node¶
The ML-Model-Provider Node is responsible for selecting a machine learning model suited to the user's problem and providing it (model name and ONNX model path) to the rest of the workflow.
1.1.7.1. Inputs and Outputs¶
The following table summarizes the inputs and outputs of the ML-Model-Provider Node:
| Input | From Node | Output | To Node |
|-------|-----------|--------|---------|
1.1.7.2. Node Template¶
Following is the Python API provided for the ML-Model-Provider Node.
The user is meant to implement the functionality of the node within task_callback(), and the response to configuration requests from the orchestrator within configuration_callback().
# Copyright 2023 SustainML Consortium
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SustainML ML Model Provider Node Implementation."""
from sustainml_py.nodes.MLModelNode import MLModelNode
# Manage signaling
import os
import signal
import threading
import time
import json
from rdftool.ModelONNXCodebase import model
from rdftool.rdfCode import load_graph, get_models_for_problem, get_models_for_problem_and_tag
from rag.rag_backend import answer_question
# Whether to go on spinning or interrupt
running = False


# Load the list of unsupported model names
def load_unsupported_models(file_path):
    """Return the lower-cased, non-empty lines of ``file_path``.

    Each line is stripped of surrounding whitespace and blank lines are
    skipped. Loading is best-effort: if the file cannot be opened or decoded,
    a warning is printed and an empty list is returned.
    """
    try:
        # Explicit encoding so the result does not depend on the host locale.
        with open(file_path, 'r', encoding='utf-8') as f:
            return [entry.lower() for entry in (line.strip() for line in f) if entry]
    except (OSError, UnicodeDecodeError) as e:
        print(f"[WARN] Could not load unsupported list: {e}")
        return []


# Model names listed in unsupported_models.txt next to this file.
# abspath + os.path.join build the path portably; plain "+ '/...'" concatenation
# resolves to the filesystem root when dirname(__file__) is "".
unsupported_models = load_unsupported_models(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "unsupported_models.txt")
)
# Signal handler
def signal_handler(sig, frame):
    """Stop the node and clear the module-level ``running`` flag.

    Installed for SIGINT when this file runs as a script. ``sig`` and
    ``frame`` are the standard ``signal.signal`` handler arguments and are
    not used.
    """
    global running
    print("\nExiting")
    MLModelNode.terminate()
    running = False
# User Callback implementation
# Inputs: ml_model_metadata, app_requirements, hw_constraints, ml_model_baseline, hw_baseline, carbonfootprint_baseline
# Outputs: node_status, ml_model
def _parse_extra_data(raw):
    """Decode an ``extra_data`` payload into a dict.

    ``raw`` is iterated as a sequence of byte values (0-255). It is decoded
    as UTF-8, the encoding this node uses for its own ``extra_data``, and
    parsed as JSON. An empty payload yields ``{}``. A payload that is not
    valid UTF-8 or JSON, or is JSON but not an object, also yields ``{}``
    after printing a warning.
    """
    if not raw:
        return {}
    try:
        # A per-byte chr() join would decode as Latin-1 and mangle non-ASCII
        # text (e.g. accented problem descriptions).
        parsed = json.loads(bytes(raw).decode("utf-8"))
    except ValueError:
        # Covers json.JSONDecodeError, UnicodeDecodeError and byte values
        # outside 0-255 rejected by bytes().
        parsed = None
    if not isinstance(parsed, dict):
        print("[WARN] In model_provider node extra_data JSON is not valid.")
        return {}
    return parsed


def task_callback(ml_model_metadata,
                  app_requirements,
                  hw_constraints,
                  ml_model_baseline,
                  hw_baseline,
                  carbonfootprint_baseline,
                  node_status,
                  ml_model):
    """Select an ML model for the incoming task and fill in ``ml_model``.

    The model is taken from ``extra_data["model_selected"]`` when that key is
    present and not null. Otherwise it is chosen with ``answer_question``
    (RAG), using the first entry of ``ml_model_metadata.ml_model_metadata()``
    and ``extra_data["problem_short_description"]``.

    On success, ``ml_model`` receives:
      * the model name;
      * the ONNX path returned by ``model()``;
      * a UTF-8 JSON ``extra_data`` of the form
        ``{"unsupported_models": [...]}``.

    On any failure, model and path are set to ``"Error"`` and ``extra_data``
    to ``{"error": "<message>"}``. Exceptions are reported this way and not
    raised to the caller.

    ``app_requirements``, ``hw_constraints``, ``ml_model_baseline``,
    ``hw_baseline``, ``carbonfootprint_baseline`` and ``node_status`` are
    currently unused.
    """
    print(f"Received Task: {ml_model_metadata.task_id().problem_id()},{ml_model_metadata.task_id().iteration_id()}")
    try:
        extra_data_dict = _parse_extra_data(ml_model_metadata.extra_data())
        # NOTE(review): upstream nodes may also send "type" and
        # "model_restrains" (model restriction); they are not used here yet.

        chosen_model = None
        if "model_selected" in extra_data_dict:
            chosen_model = extra_data_dict["model_selected"]
            print("Model already selected: ", chosen_model)

        if chosen_model is None:
            # The description and metadata are only needed when the model has
            # to be chosen here, so they are not required otherwise.
            if "problem_short_description" not in extra_data_dict:
                raise ValueError("extra_data is missing 'problem_short_description'")
            problem_short_description = extra_data_dict["problem_short_description"]
            metadata = ml_model_metadata.ml_model_metadata()[0]
            print(f"Problem short description: {problem_short_description}")
            # Choose model with the RAG based on the goal selected and the knowledge of the graph.
            chosen_model = answer_question(
                f"Task {metadata} with problem description: {problem_short_description}?"
            )
        print(f"ML Model chosen: {chosen_model}")

        # Generate model code and keywords
        onnx_path = model(chosen_model)  # TODO - Further development needed
        ml_model.model(chosen_model)
        ml_model.model_path(onnx_path)

        # Add unsupported_models information to extra_data in json format
        extra_data = {"unsupported_models": unsupported_models}
        ml_model.extra_data(json.dumps(extra_data).encode("utf-8"))
    except Exception as e:
        # Callback boundary: report the failure through the output message.
        print(f"Failed to determine ML model for task {ml_model_metadata.task_id()}: {e}.")
        ml_model.model("Error")
        ml_model.model_path("Error")
        error_message = "Failed to obtain ML model for task: " + str(e)
        error_info = {"error": error_message}
        ml_model.extra_data(json.dumps(error_info).encode("utf-8"))
# User Configuration Callback implementation
# Inputs: req
# Outputs: res
def configuration_callback(req, res):
    """Answer a configuration request from the orchestrator.

    Supported request: ``"model_from_goal, <goal>[, <tag>]"``. ``node_id`` and
    ``transaction_id`` are always echoed from ``req``.

    The response ``configuration`` is JSON:
      * ``{"models": "<comma-separated sorted names>"}`` after a lookup;
      * ``{"error": "<message>"}`` when the request is unsupported or the
        lookup raised.

    ``success``/``err_code`` are ``True``/``0`` when at least one model was
    found, and ``False``/``1`` otherwise.
    """
    config = req.configuration()
    res.node_id(req.node_id())
    res.transaction_id(req.transaction_id())

    if 'model_from_goal' not in config:
        error_msg = f"Unsupported configuration request: {config}"
        res.configuration(json.dumps({"error": error_msg}))
        res.success(False)
        res.err_code(1)  # 0: No error || 1: Error
        print(error_msg)
        return

    try:
        # Take everything after the command name and drop one separating
        # comma, so "model_from_goal,goal" parses like "model_from_goal, goal"
        # (a fixed-length slice would cut into the goal).
        text = config.partition('model_from_goal')[2].lstrip()
        if text.startswith(','):
            text = text[1:]
        parts = text.split(',')
        if len(parts) >= 2:
            goal = parts[0].strip()
            tag = parts[1].strip()
            models = get_models_for_problem_and_tag(goal, tag)
        else:
            goal = text.strip()
            models = get_models_for_problem(goal)
        # The first element of each result is used as the model name.
        sorted_models = ', '.join(sorted(str(m[0]) for m in models))
        found = bool(sorted_models)
        res.success(found)
        res.err_code(0 if found else 1)  # 0: No error || 1: Error
        print(f"Models for {goal}: {sorted_models}")  # debug
        res.configuration(json.dumps({"models": sorted_models}))
    except Exception as e:
        error_msg = f"Error getting model from goal from request: {e}"
        print(error_msg)
        # Report the failure in the payload as well, like the unsupported-request path.
        res.configuration(json.dumps({"error": error_msg}))
        res.success(False)
        res.err_code(1)
# Main workflow routine

# Process exit status reported by the main guard. run() executes in a worker
# thread, where exit()/sys.exit() raise SystemExit that only ends that thread
# (the process would still exit with 0), so failures are recorded here instead.
_exit_code = 0


def run():
    """Wait up to ~5 s for the knowledge graph, then create and spin the node.

    If ``load_graph()`` does not succeed before the deadline, an error is
    printed, ``_exit_code`` is set to 1 and the function returns without
    creating the node. Otherwise the module-level ``running`` flag is set and
    the call blocks in ``node.spin()``.
    """
    global running, _exit_code
    # monotonic() is unaffected by wall-clock adjustments during the wait.
    deadline = time.monotonic() + 5
    loaded = False
    while time.monotonic() < deadline:
        if load_graph():
            loaded = True
            break
        time.sleep(0.1)
    if not loaded:
        print("[Error] Graph not available")
        _exit_code = 1
        return
    node = MLModelNode(callback=task_callback, service_callback=configuration_callback)
    running = True
    node.spin()


# Call main in program execution
if __name__ == '__main__':
    import sys

    signal.signal(signal.SIGINT, signal_handler)
    # Python signal handlers only run in the main thread, between bytecodes,
    # so the blocking spin() runs in a worker thread while the main thread
    # waits in short join() slices and stays responsive to SIGINT.
    runner = threading.Thread(target=run)
    runner.start()
    # Poll the thread itself rather than `running`: that flag is only set
    # after the graph loads, so testing it right after start() would usually
    # fall straight through to an unbounded join().
    while runner.is_alive():
        runner.join(timeout=1)
    sys.exit(_exit_code)