# File: optimize_main.py
# (c) Zizhe Wang
# https://zizhe.wang
############################
#                          #
# MAIN OPTIMIZATION SCRIPT #
#                          #
############################
import numpy as np
from joblib import Parallel, delayed
import matplotlib.pyplot as plt
from pymoo.core.problem import Problem
from pymoo.optimize import minimize
from optimization_libraries import initialize_algorithm
from parallel_computing import optimization_function, cleanup_temp_dirs
from config import (PARAMETERS, RESULTS, PARAM_BOUNDS, PRECISION, PLOT_CONFIG,
OPTIMIZATION_LIBRARY, ALGORITHM_NAME, POP_SIZE, N_GEN, N_JOBS) # Import all configuration variables
class OptimizationProblem(Problem):
    """Optimization problem built from the imported configuration values.

    The decision variables are the keys of ``PARAM_BOUNDS``; each entry
    supplies the lower bound at index 0 and the upper bound at index 1.
    The problem has one objective per entry in ``RESULTS`` and no
    constraints.
    """

    def __init__(self):
        """Collect variable names and bounds, then initialise the base class."""
        self.param_names = list(PARAM_BOUNDS.keys())
        variable_count = len(self.param_names)

        # Look up each bound pair once, then split it into lower/upper arrays.
        bound_pairs = [PARAM_BOUNDS[name] for name in self.param_names]
        lower = np.array([pair[0] for pair in bound_pairs])
        upper = np.array([pair[1] for pair in bound_pairs])

        print(f"Number of variables: {variable_count}")
        print(f"Lower bounds: {lower}")
        print(f"Upper bounds: {upper}")

        super().__init__(
            n_var=variable_count,
            n_obj=len(RESULTS),
            n_constr=0,
            xl=lower,
            xu=upper,
        )

    def _evaluate(self, X, out, *args, **kwargs):
        """Evaluate every row of ``X`` and store the results in ``out["F"]``.

        Each row is paired with ``self.param_names`` to form a
        name -> value dict, which is passed to ``optimization_function``.
        The calls are dispatched through joblib with ``N_JOBS`` workers.
        """
        tasks = [
            delayed(optimization_function)(dict(zip(self.param_names, row)))
            for row in X
        ]
        evaluated = Parallel(n_jobs=N_JOBS)(tasks)
        # Collect the per-candidate results into a single array.
        out["F"] = np.array(evaluated)
# Initialize the optimization algorithm
algorithm = initialize_algorithm(OPTIMIZATION_LIBRARY, ALGORITHM_NAME, POP_SIZE)

# Define the optimization problem
problem = OptimizationProblem()

# Run the optimization. Cleanup sits in a `finally` clause so that
# cleanup_temp_dirs() is still called when the run raises.
try:
    res = minimize(problem, algorithm, ("n_gen", N_GEN), verbose=True)
finally:
    cleanup_temp_dirs()

# Normalise the objective values to a 2-D (solutions x objectives) array so
# the code below also works when the result is missing or one-dimensional.
if res.F is None:
    objective_values = np.empty((0, len(RESULTS)))
else:
    objective_values = np.atleast_2d(res.F)

# Print the results
print("Optimization Results:")
for i, result in enumerate(objective_values):
    print(f"Solution {i}: ", end="")
    for name, value in zip(RESULTS, result):
        print(f"{name.capitalize()} = {value:.{PRECISION}f}", end=", ")
    print()

# Plot the results (first objective against the second one)
if objective_values.shape[1] >= 2:
    # The labels were previously read from PLOT_X / PLOT_Y / PLOT_TITLE, which
    # are not defined in this file; only PLOT_CONFIG is imported.
    # NOTE(review): assumes PLOT_CONFIG is a mapping with the keys "PLOT_X",
    # "PLOT_Y" and "PLOT_TITLE" -- confirm against config.py. The fallbacks
    # keep the plot usable if a key is absent.
    x_label = PLOT_CONFIG.get("PLOT_X", "Objective 1")
    y_label = PLOT_CONFIG.get("PLOT_Y", "Objective 2")
    plot_title = PLOT_CONFIG.get("PLOT_TITLE", "Pareto Front")

    plt.figure(figsize=(8, 6))
    plt.scatter(objective_values[:, 0], objective_values[:, 1])
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(plot_title)
    plt.grid(True)
    plt.tight_layout()
    plt.show()
else:
    # A 2-D scatter plot needs two objective columns.
    print("Skipping plot: at least two objectives are required.")