Spaces:
Build error
Build error
| import streamlit as st | |
| import pandas as pd | |
| import json | |
| from scenario import Channel, Scenario | |
| import numpy as np | |
| from plotly.subplots import make_subplots | |
| import plotly.graph_objects as go | |
| from scenario import class_to_dict | |
| from collections import OrderedDict | |
| import io | |
| import plotly | |
| from pathlib import Path | |
| import pickle | |
| import yaml | |
| from yaml import SafeLoader | |
| from streamlit.components.v1 import html | |
| import smtplib | |
| from scipy.optimize import curve_fit | |
| from sklearn.metrics import r2_score | |
| from scenario import class_from_dict, class_convert_to_dict | |
| import os | |
| import base64 | |
| import sqlite3 | |
| import datetime | |
| from scenario import numerize | |
| import psycopg2 | |
| # | |
| import re | |
| import bcrypt | |
| import os | |
| import json | |
| import glob | |
| import pickle | |
| import streamlit as st | |
| import streamlit as st | |
| import pandas as pd | |
| import json | |
| from scenario import Channel, Scenario | |
| import numpy as np | |
| from plotly.subplots import make_subplots | |
| import plotly.graph_objects as go | |
| from scenario import class_to_dict | |
| from collections import OrderedDict | |
| import io | |
| import plotly | |
| from pathlib import Path | |
| import pickle | |
| import yaml | |
| from yaml import SafeLoader | |
| from streamlit.components.v1 import html | |
| import smtplib | |
| from scipy.optimize import curve_fit | |
| from sklearn.metrics import r2_score | |
| from scenario import class_from_dict, class_convert_to_dict | |
| import os | |
| import base64 | |
| import sqlite3 | |
| import datetime | |
| from scenario import numerize | |
| import sqlite3 | |
# Hex colour codes that make up the app palette (consumers are outside this
# view). Kept as one whitespace-separated string; .split() yields the list in
# this exact order.
# NOTE(review): a schema lookup from db_cred["schema"] used to be here; it is
# disabled because db_cred below is None.
color_palette = (
    "#F3F3F0 #5E7D7E #2FA1FF #00EDED #00EAE4 #304550 "
    "#EDEBEB #7FBEFD #003059 #A2F3F3 #E1D6E2 #B6B6B6"
).split()

# Currency symbol (presumably prefixed to monetary figures elsewhere).
CURRENCY_INDICATOR = "$"

# No database credentials: the active query helper opens a local SQLite file
# and never reads this value. Callers still pass it through.
db_cred = None
# Legacy code kept commented out for reference:
#   - a SQLite user-DB connection;
#   - a PostgreSQL (psycopg2) implementation of query_excecuter_postgres.
# The active query_excecuter_postgres defined below uses SQLite instead.
# database_file = r"DB/User.db"
# conn = sqlite3.connect(database_file, check_same_thread=False) # connection with sql db
# c = conn.cursor()
# def query_excecuter_postgres(
#     query,
#     db_cred,
#     params=None,
#     insert=True,
#     insert_retrieve=False,
# ):
#     """
#     Executes a SQL query on a PostgreSQL database, handling both insert and select operations.
#     Parameters:
#     query (str): The SQL query to be executed.
#     params (tuple, optional): Parameters to pass into the SQL query for parameterized execution.
#     insert (bool, default=True): Flag to determine if the query is an insert operation (default) or a select operation.
#     insert_retrieve (bool, default=False): Flag to determine if the query should insert and then return the inserted ID.
#     """
#     # Database connection parameters
#     dbname = db_cred["dbname"]
#     user = db_cred["user"]
#     password = db_cred["password"]
#     host = db_cred["host"]
#     port = db_cred["port"]
#     try:
#         # Establish connection to the PostgreSQL database
#         conn = psycopg2.connect(
#             dbname=dbname, user=user, password=password, host=host, port=port
#         )
#     except psycopg2.Error as e:
#         st.warning(f"Unable to connect to the database: {e}")
#         st.stop()
#     # Create a cursor object to interact with the database
#     c = conn.cursor()
#     try:
#         # Execute the query with or without parameters
#         if params:
#             c.execute(query, params)
#         else:
#             c.execute(query)
#         if not insert:
#             # If not an insert operation, fetch and return the results
#             results = c.fetchall()
#             return results
#         elif insert_retrieve:
#             # If insert and retrieve operation, fetch and return the results
#             conn.commit()
#             return c.fetchall()
#         else:
#             conn.commit()
#     except Exception as e:
#         st.write(f"Error executing query: {e}")
#     finally:
#         conn.close()

# Module-level SQLite file name.
# NOTE(review): query_excecuter_postgres never reads this module variable.
# Its own `db_path` parameter shadows it, and it connects to db/imp_db.db
# regardless. os.path.join with a single argument returns "imp_db.db"
# unchanged.
db_path = os.path.join("imp_db.db")
def query_excecuter_postgres(
    query, db_path=None, params=None, insert=True, insert_retrieve=False, db_cred=None
):
    """
    Execute a SQL statement against the local SQLite database (db/imp_db.db).

    The name is historical: the function used to target PostgreSQL. It is kept
    for backward compatibility with existing callers.

    Parameters:
    query (str): SQL text using ``?`` placeholders.
    db_path: Ignored; the database always lives at ``db/imp_db.db``. Many
        callers pass ``db_cred`` positionally into this slot, so the value
        must not be interpreted as a path.
    params (sequence, optional): Values bound to the placeholders. If the
        query contains ``IN (?)`` and more values are supplied than there are
        ``?`` placeholders, that ``IN (?)`` is widened to absorb the surplus.
    insert (bool, default=True): True for statements that must be committed.
        False for SELECT-style statements, whose rows are fetched and returned.
    insert_retrieve (bool, default=False): With ``insert=True``, commit and
        return ``cursor.lastrowid``.
    db_cred: Ignored; accepted for backward compatibility.

    Returns:
    list[tuple] when ``insert`` is False; the last inserted row id when
    ``insert_retrieve`` is True; otherwise None.

    Raises:
    sqlite3.Error raised while executing the statement propagates to the
    caller. The connection is closed on every path.
    """
    try:
        # Construct a cross-platform path to the database
        db_dir = os.path.join("db")
        os.makedirs(db_dir, exist_ok=True)  # Make sure the directory exists
        db_file = os.path.join(db_dir, "imp_db.db")
        # Establish connection to the SQLite database
        conn = sqlite3.connect(db_file)
    except sqlite3.Error as e:
        st.warning(f"Unable to connect to the SQLite database: {e}")
        st.stop()

    try:
        c = conn.cursor()
        if params:
            # Widen a single `IN (?)` only when more values than placeholders
            # were supplied. Expanding unconditionally to len(params) broke
            # queries that already had the right number of placeholders,
            # e.g. "prj_id = ? AND page_nam IN (?)" with two params.
            surplus = len(params) - query.count("?")
            if surplus > 0 and "IN (?)" in query:
                widened = ", ".join("?" * (surplus + 1))
                query = query.replace("IN (?)", f"IN ({widened})", 1)
            c.execute(query, params)
        else:
            c.execute(query)
        try:
            if not insert:
                # SELECT-style statement: hand back every row.
                return c.fetchall()
            conn.commit()
            if insert_retrieve:
                # Row id generated by the INSERT just committed.
                return c.lastrowid
        except Exception as e:
            # Preserved behaviour: fetch/commit failures are reported in the
            # UI rather than raised.
            st.write(f"Error executing query: {e}")
    finally:
        # Runs even when execute() raises; previously that path leaked the
        # connection.
        conn.close()
def update_summary_df():
    """
    Refresh ``st.session_state['project_summary_df']`` for the logged-in user.

    The summary lists each of the user's projects with the page and time of
    its latest metadata update, newest first.

    Reads ``st.session_state["emp_id"]`` as the project owner id.

    Uses:
    - query_excecuter_postgres(query, params=..., insert=False): runs the
      query against the SQLite database.

    Modifies:
    - st.session_state['project_summary_df']: DataFrame with columns
      'Project Number', 'Project Name', 'Last Modified Page',
      'Last Modified Time', sorted by 'Last Modified Time' descending.

    Returns:
    The same DataFrame that is stored in session state.

    Projects with no rows in mmo_project_meta_data are not listed.
    """
    # INNER JOIN with the rn = 1 filter returns exactly the rows the previous
    # "RIGHT JOIN ... WHERE lu.rn = 1" did: that WHERE clause already discarded
    # the unmatched right-join rows. RIGHT JOIN is only supported from SQLite
    # 3.39, so the old form failed on older SQLite builds.
    query = """
    WITH LatestUpdates AS (
        SELECT
            prj_id,
            page_nam,
            updt_dt_tm,
            ROW_NUMBER() OVER (PARTITION BY prj_id ORDER BY updt_dt_tm DESC) AS rn
        FROM
            mmo_project_meta_data
    )
    SELECT
        p.prj_id,
        p.prj_nam AS prj_nam,
        lu.page_nam,
        lu.updt_dt_tm
    FROM
        mmo_projects p
    INNER JOIN
        LatestUpdates lu ON lu.prj_id = p.prj_id AND lu.rn = 1
    WHERE
        p.prj_ownr_id = ?
    """
    params = (st.session_state["emp_id"],)  # Parameters for the SQL query
    # Execute the query and retrieve project summary data
    project_summary = query_excecuter_postgres(
        query, params=params, insert=False, db_cred=db_cred
    )
    summary_df = pd.DataFrame(
        project_summary,
        columns=[
            "Project Number",
            "Project Name",
            "Last Modified Page",
            "Last Modified Time",
        ],
    ).sort_values(by=["Last Modified Time"], ascending=False)
    # Update the session state with the project summary dataframe
    st.session_state["project_summary_df"] = summary_df
    return summary_df
| from constants import default_dct | |
def ensure_project_dct_structure(session_state, default_dct):
    """
    Recursively add keys missing from ``session_state`` using ``default_dct``.

    Values that are already present are never overwritten.

    Missing values are deep-copied so the target never aliases the defaults.
    Without the copy, later edits to the filled-in project dict silently
    changed ``default_dct`` itself, and every project loaded afterwards
    inherited those edits.

    Nested mappings are merged recursively only when both the default value
    and the existing value are mutable mappings. An existing value of another
    type is left untouched.

    Parameters:
    session_state (MutableMapping): Structure to fill in place.
    default_dct (Mapping): Template of default keys/values.
    """
    import copy
    from collections.abc import MutableMapping

    for key, default_value in default_dct.items():
        if key not in session_state:
            session_state[key] = copy.deepcopy(default_value)
        elif isinstance(default_value, dict) and isinstance(
            session_state[key], MutableMapping
        ):
            ensure_project_dct_structure(session_state[key], default_value)
def project_selection():
    """
    Render the login form and load the user's most recent project.

    The form asks for an employee id and password. When "Login" is clicked
    with valid credentials, the most recently modified project is loaded into
    ``st.session_state``.

    On success it sets ``emp_id``, ``username``, ``project_summary_df``,
    ``project_name``, ``project_number`` and ``project_dct`` in session state,
    shows a success message and calls ``st.rerun()``. Every failure path
    shows a message and calls ``st.stop()``.
    """
    emp_id = st.text_input("employee id", key="emp1111").lower()
    password = st.text_input("Password", max_chars=15, type="password")
    if st.button("Login"):
        # NOTE(review): the user list is cached in session state on the first
        # login attempt only. Users added to mmo_users later are not recognised
        # until the session is reset.
        if "unique_ids" not in st.session_state:
            unique_users_query = f"""
            SELECT DISTINCT emp_id, emp_nam, emp_typ from mmo_users;
            """
            unique_users_result = query_excecuter_postgres(
                unique_users_query, db_cred, insert=False
            )  # retrieves all the users who have access to the MMO tool
            # Maps emp_id -> (emp_nam, emp_typ). The comprehension's loop
            # variable `emp_id` is local to the comprehension and does not
            # overwrite the emp_id typed above.
            st.session_state["unique_ids"] = {
                emp_id: (emp_nam, emp_type)
                for emp_id, emp_nam, emp_type in unique_users_result
            }
        # Unknown id or empty password: reject before touching password data.
        if emp_id not in st.session_state["unique_ids"].keys() or len(password) == 0:
            st.warning("invalid id or password!")
            st.stop()
        # Users whose pswrd_flag is not 1 must first reset their password.
        if not is_pswrd_flag_set(emp_id):
            st.warning("Reset password in home page to continue")
            st.stop()
        elif not verify_password(emp_id, password):
            st.warning("Invalid user name or password")
            st.stop()
        else:
            st.session_state["emp_id"] = emp_id
            # First element of the (emp_nam, emp_typ) tuple is the display name.
            st.session_state["username"] = st.session_state["unique_ids"][
                st.session_state["emp_id"]
            ][0]
            with st.spinner("Loading Saved Projects"):
                st.session_state["project_summary_df"] = update_summary_df()
            # st.write(st.session_state["project_name"][0])
            if len(st.session_state["project_summary_df"]) == 0:
                st.warning("No projects found please create a project in Home page")
                st.stop()
            else:
                try:
                    # update_summary_df sorts by 'Last Modified Time'
                    # descending, so row 0 is the most recently modified
                    # project.
                    st.session_state["project_name"] = (
                        st.session_state["project_summary_df"]
                        .loc[
                            st.session_state["project_summary_df"]["Project Number"]
                            == st.session_state["project_summary_df"].iloc[0, 0],
                            "Project Name",
                        ]
                        .values[0]
                    )  # fetching project name from project number stored in summary df
                    poroject_dct_query = f"""
                    SELECT pkl_obj FROM mmo_project_meta_data WHERE prj_id = ? AND file_nam=?;
                    """
                    # Execute the query and retrieve the result
                    project_number = int(st.session_state["project_summary_df"].iloc[0, 0])
                    st.session_state["project_number"] = project_number
                    project_dct_retrieved = query_excecuter_postgres(
                        poroject_dct_query,
                        db_cred,
                        params=(project_number, "project_dct"),
                        insert=False,
                    )
                    # retrieves project dict (meta data) stored in db
                    # SECURITY NOTE(review): pickle.loads runs arbitrary code
                    # embedded in the blob; only safe while the DB is trusted.
                    st.session_state["project_dct"] = pickle.loads(
                        project_dct_retrieved[0][0]
                    )  # converting bytes data back to the original object using pickle
                    # Back-fill any keys added to default_dct since this
                    # project was saved.
                    ensure_project_dct_structure(
                        st.session_state["project_dct"], default_dct
                    )
                    st.success("Project Loded")
                    st.rerun()
                except Exception as e:
                    # NOTE(review): the exception `e` is discarded, so the
                    # user never sees why loading failed.
                    st.write(
                        "Failed to load project meta data from db please create new project!"
                    )
                    st.stop()
def update_db(prj_id, page_nam, file_nam, pkl_obj, resp_mtrc="", schema=""):
    """
    Save ``pkl_obj`` in mmo_project_meta_data under (``prj_id``, ``file_nam``).

    If a row with that project id and file name exists, its file_nam, pkl_obj,
    page_nam and updt_dt_tm are overwritten. Otherwise a new row is inserted
    with:
    - creator ``st.session_state["emp_id"]``;
    - both timestamps set by SQLite's datetime('now'), which is UTC.

    Parameters:
    prj_id: Project id.
    page_nam (str): Page that produced the object.
    file_nam (str): Name of the stored object; part of the row key.
    pkl_obj: Pickled payload (read back with pickle.loads by
        retrieve_pkl_object).
    resp_mtrc, schema: Unused; kept for signature compatibility.
    """
    existing_rows = query_excecuter_postgres(
        """
    SELECT 1 FROM mmo_project_meta_data
    WHERE prj_id = ? AND file_nam =?;
    """,
        db_cred,
        params=(prj_id, file_nam),
        insert=False,
    )

    if not existing_rows:
        # No row yet for this (project, file): create one.
        insert_sql = """
    INSERT INTO mmo_project_meta_data
    (prj_id, page_nam, file_nam, pkl_obj,crte_by_uid, crte_dt_tm, updt_dt_tm)
    VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'));
    """
        creator_id = st.session_state["emp_id"]
        query_excecuter_postgres(
            insert_sql,
            db_cred,
            params=(prj_id, page_nam, file_nam, pkl_obj, creator_id),
            insert=True,
        )
        return

    # Row already exists: overwrite payload, page and modification time.
    update_sql = """
    UPDATE mmo_project_meta_data
    SET file_nam = ?, pkl_obj = ?, page_nam=? ,updt_dt_tm = datetime('now')
    WHERE prj_id = ? AND file_nam = ?;
    """
    query_excecuter_postgres(
        update_sql,
        db_cred,
        params=(file_nam, pkl_obj, page_nam, prj_id, file_nam),
        insert=True,
    )
def retrieve_pkl_object(prj_id, page_nam, file_nam, schema=""):
    """
    Load and unpickle the object stored for (prj_id, page_nam, file_nam).

    Returns None when no row matches or the stored value is empty/NULL.
    ``schema`` is unused.

    SECURITY NOTE(review): pickle.loads can execute arbitrary code, so this is
    only safe while the database contents are trusted.
    """
    lookup_sql = """
    SELECT pkl_obj FROM mmo_project_meta_data
    WHERE prj_id = ? AND page_nam = ? AND file_nam = ?;
    """
    rows = query_excecuter_postgres(
        lookup_sql,
        db_cred=db_cred,
        params=(prj_id, page_nam, file_nam),
        insert=False,
    )
    has_payload = bool(rows) and bool(rows[0]) and bool(rows[0][0])
    if not has_payload:
        return None
    return pickle.loads(rows[0][0])
def validate_text(input_text):
    """
    Check that ``input_text`` is a 2-30 character identifier.

    Only ASCII letters, digits and underscores are allowed.

    Returns:
    tuple[bool, str]: (is_valid, human-readable message).
    """
    # Check the length of the text
    if len(input_text) < 2:
        return False, "Input should be at least 2 characters long."
    if len(input_text) > 30:
        return False, "Input should not exceed 30 characters."
    # fullmatch is required here: with re.match(r"^...$"), `$` also matches
    # just before a trailing newline, so "name\n" was wrongly accepted.
    if not re.fullmatch(r"[A-Za-z0-9_]+", input_text):
        return (
            False,
            "Input contains invalid characters. Only letters, numbers and underscores are allowed.",
        )
    return True, "Input is valid."
def delete_entries(prj_id, page_names, db_cred=None, schema=None):
    """
    Delete mmo_project_meta_data rows for ``prj_id`` whose page is in ``page_names``.

    Parameters:
    prj_id (int): The project ID.
    page_names (iterable of str): Page names to delete. An empty iterable is
        a no-op.
    db_cred: Forwarded to query_excecuter_postgres.
    schema (str): Unused; kept for signature compatibility.
    """
    page_names = list(page_names)  # also accepts generators / other iterables
    if not page_names:
        return
    # The spaces inside the parentheses are deliberate. query_excecuter_postgres
    # has historically rewritten the literal text "IN (?)", which a single page
    # name would produce here. That rewrite added an extra placeholder and
    # made the DELETE fail with a binding-count error.
    placeholders = ", ".join("?" for _ in page_names)
    query = f"""
    DELETE FROM mmo_project_meta_data
    WHERE prj_id = ? AND page_nam IN ( {placeholders} );
    """
    # Combine prj_id and page_names into one list of parameters
    params = (prj_id, *page_names)
    query_excecuter_postgres(query, params=params, insert=True, db_cred=db_cred)
def store_hashed_password(
    user_id,
    plain_text_password,
):
    """
    Hash ``plain_text_password`` with bcrypt and store it for ``user_id``.

    The result is written as ``pswrd_key`` on the ``mmo_users`` row whose
    ``emp_id`` equals ``user_id``.

    Parameters:
    user_id: emp_id of the user to update.
    plain_text_password (str): Password to hash; it is UTF-8 encoded first.

    The salted hash is stored as a UTF-8 string rather than bytes.
    """
    password_bytes = plain_text_password.encode("utf-8")
    digest_text = bcrypt.hashpw(password_bytes, bcrypt.gensalt()).decode("utf-8")

    update_sql = """
    UPDATE mmo_users
    SET pswrd_key = ?
    WHERE emp_id = ?;
    """
    query_excecuter_postgres(
        query=update_sql,
        db_cred=db_cred,
        params=(digest_text, user_id),
        insert=True,
    )
def verify_password(user_id, plain_text_password):
    """
    Check a plain-text password against the user's stored bcrypt hash.

    Returns True when ``plain_text_password`` matches ``mmo_users.pswrd_key``
    for ``user_id``; otherwise False.

    Parameters:
    user_id: emp_id of the user being verified.
    plain_text_password (str): Password as typed by the user.

    The following all yield False instead of raising:
    - the user does not exist;
    - the stored hash is NULL or empty;
    - the stored value cannot be parsed by bcrypt.
    """
    # SQL query to retrieve the hashed password for the user_id
    query = """
    SELECT pswrd_key FROM mmo_users WHERE emp_id = ?;
    """
    result = query_excecuter_postgres(
        query=query, db_cred=db_cred, params=(user_id,), insert=False
    )
    if not result:
        return False
    stored_hash = result[0][0]
    if not stored_hash:
        # NULL/empty pswrd_key (no password stored yet). This used to crash on
        # None.encode(...).
        return False
    if isinstance(stored_hash, str):
        # Hashes are stored as UTF-8 text; bcrypt needs bytes.
        stored_hash = stored_hash.encode("utf-8")
    try:
        return bcrypt.checkpw(plain_text_password.encode("utf-8"), stored_hash)
    except ValueError:
        # bcrypt raises ValueError (e.g. "Invalid salt") for malformed hashes.
        return False
def update_password_in_db(user_id, plain_text_password):
    """
    Replace a user's stored password hash.

    Hashes ``plain_text_password`` with a fresh bcrypt salt and writes it to
    ``mmo_users.pswrd_key`` for the row with ``emp_id = user_id``.

    Parameters:
    user_id: emp_id whose password is replaced.
    plain_text_password (str): New password in plain text.
    """
    encoded = plain_text_password.encode("utf-8")
    stored_form = bcrypt.hashpw(encoded, bcrypt.gensalt()).decode("utf-8")

    sql = """
    UPDATE mmo_users
    SET pswrd_key = ?
    WHERE emp_id = ?
    """
    query_excecuter_postgres(sql, db_cred, params=(stored_form, user_id), insert=True)
def is_pswrd_flag_set(user_id):
    """
    Return True only when ``mmo_users.pswrd_flag`` is 1 for ``user_id``.

    A missing user or any other flag value yields False.
    """
    flag_query = """
    SELECT pswrd_flag
    FROM mmo_users
    WHERE emp_id = ?;
    """
    rows = query_excecuter_postgres(flag_query, db_cred, params=(user_id,), insert=False)
    return bool(rows) and rows[0][0] == 1
def set_pswrd_flag(user_id):
    """
    Set ``mmo_users.pswrd_flag`` to 1 for ``user_id``.

    This is the value is_pswrd_flag_set checks for.
    """
    query_excecuter_postgres(
        """
    UPDATE mmo_users
    SET pswrd_flag = 1
    WHERE emp_id = ?;
    """,
        db_cred,
        params=(user_id,),
        insert=True,
    )
def retrieve_pkl_object_without_warning(prj_id, page_nam, file_nam, schema):
    """
    Unpickle the object stored for (prj_id, page_nam, file_nam).

    Returns None silently when no row matches or the stored value is
    empty/NULL. ``schema`` is required but unused.

    SECURITY NOTE(review): pickle.loads can execute arbitrary code, so this is
    only safe while the database contents are trusted.
    """
    rows = query_excecuter_postgres(
        """
    SELECT pkl_obj FROM mmo_project_meta_data
    WHERE prj_id = ? AND page_nam = ? AND file_nam = ?;
    """,
        db_cred=db_cred,
        params=(prj_id, page_nam, file_nam),
        insert=False,
    )
    blob = rows[0][0] if rows and rows[0] else None
    return pickle.loads(blob) if blob else None
# NOTE(review): this re-binds color_palette and CURRENCY_INDICATOR to the same
# values already assigned near the top of the module. One of the two
# definitions can likely be removed.
color_palette = [
    "#F3F3F0",
    "#5E7D7E",
    "#2FA1FF",
    "#00EDED",
    "#00EAE4",
    "#304550",
    "#EDEBEB",
    "#7FBEFD",
    "#003059",
    "#A2F3F3",
    "#E1D6E2",
    "#B6B6B6",
]
CURRENCY_INDICATOR = "$"
# Legacy user-DB connection and streamlit-authenticator login helpers, kept
# commented out for reference. The active login flow is project_selection.
# NOTE(review): `stauth` is not imported in this file, so these helpers would
# need that import if revived.
# database_file = r"DB/User.db"
# conn = sqlite3.connect(database_file, check_same_thread=False) # connection with sql db
# c = conn.cursor()
# def load_authenticator():
#     with open("config.yaml") as file:
#         config = yaml.load(file, Loader=SafeLoader)
#         st.session_state["config"] = config
#     authenticator = stauth.Authenticate(
#         credentials=config["credentials"],
#         cookie_name=config["cookie"]["name"],
#         key=config["cookie"]["key"],
#         cookie_expiry_days=config["cookie"]["expiry_days"],
#         preauthorized=config["preauthorized"],
#     )
#     st.session_state["authenticator"] = authenticator
#     return authenticator
# Authentication
# def authenticator():
#     for k, v in st.session_state.items():
#         if k not in ["logout", "login", "config"] and not k.startswith(
#             "FormSubmitter"
#         ):
#             st.session_state[k] = v
#     with open("config.yaml") as file:
#         config = yaml.load(file, Loader=SafeLoader)
#         st.session_state["config"] = config
#     authenticator = stauth.Authenticate(
#         config["credentials"],
#         config["cookie"]["name"],
#         config["cookie"]["key"],
#         config["cookie"]["expiry_days"],
#         config["preauthorized"],
#     )
#     st.session_state["authenticator"] = authenticator
#     name, authentication_status, username = authenticator.login(
#         "Login", "main"
#     )
#     auth_status = st.session_state.get("authentication_status")
#     if auth_status == True:
#         authenticator.logout("Logout", "main")
#         is_state_initiaized = st.session_state.get("initialized", False)
#         if not is_state_initiaized:
#             if "session_name" not in st.session_state:
#                 st.session_state["session_name"] = None
#     return name
# def authentication():
#     with open("config.yaml") as file:
#         config = yaml.load(file, Loader=SafeLoader)
#     authenticator = stauth.Authenticate(
#         config["credentials"],
#         config["cookie"]["name"],
#         config["cookie"]["key"],
#         config["cookie"]["expiry_days"],
#         config["preauthorized"],
#     )
#     name, authentication_status, username = authenticator.login(
#         "Login", "main"
#     )
#     return authenticator, name, authentication_status, username
def nav_page(page_name, timeout_secs=3):
    """
    Switch to another page of the app from Python code.

    Injects a script via ``html`` that searches the parent document for an
    ``<a>`` whose href ends with ``"/" + page_name`` (case-insensitive) and
    clicks it. It retries every 100 ms and shows an alert once
    ``timeout_secs`` seconds have passed without a match.

    Parameters:
    page_name (str): Trailing path segment of the target page's URL.
    timeout_secs (int): How long to keep retrying, in seconds. It is
        formatted with %d, so fractional values are truncated.
    """
    # Embed the name as a JSON string literal so quotes and backslashes are
    # escaped. Also neutralise "</" so the value cannot close the <script>
    # element. It was previously spliced raw into a "..." literal.
    page_name_js = json.dumps(str(page_name)).replace("</", "<\\/")
    nav_script = """
        <script type="text/javascript">
            function attempt_nav_page(page_name, start_time, timeout_secs) {
                var links = window.parent.document.getElementsByTagName("a");
                for (var i = 0; i < links.length; i++) {
                    if (links[i].href.toLowerCase().endsWith("/" + page_name.toLowerCase())) {
                        links[i].click();
                        return;
                    }
                }
                var elasped = new Date() - start_time;
                if (elasped < timeout_secs * 1000) {
                    setTimeout(attempt_nav_page, 100, page_name, start_time, timeout_secs);
                } else {
                    alert("Unable to navigate to page '" + page_name + "' after " + timeout_secs + " second(s).");
                }
            }
            window.addEventListener("load", function() {
                attempt_nav_page(%s, new Date(), %d);
            });
        </script>
    """ % (
        page_name_js,
        timeout_secs,
    )
    html(nav_script)
# def load_local_css(file_name):
#     with open(file_name) as f:
#         st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
# def set_header():
#     return st.markdown(f"""<div class='main-header'>
#     <h1>MMM LiME</h1>
#     <img src="https://assets-global.website-files.com/64c8fffb0e95cbc525815b79/64df84637f83a891c1473c51_Vector%20(Stroke).svg ">
#     </div>""", unsafe_allow_html=True)

# Directory containing this module; the logo is resolved relative to it.
path = os.path.dirname(__file__)
# Read the logo once at import time and keep it base64-encoded for inlining
# as a data: URL (set_header embeds `data_url`).
# - os.path.join copes with an empty dirname. The old f"{path}/logo.png" would
#   have pointed at "/logo.png" at the filesystem root.
# - The `with` block closes the file even if read() raises.
with open(os.path.join(path, "logo.png"), "rb") as file_:
    contents = file_.read()
data_url = base64.b64encode(contents).decode("utf-8")
DATA_PATH = "./data"
IMAGES_PATH = "./data/images_224_224"
def load_local_css(file_name):
    """
    Inject the contents of the CSS file ``file_name`` into the page.

    The CSS is wrapped in a ``<style>`` tag and rendered with
    ``st.markdown(..., unsafe_allow_html=True)``.

    The file is read explicitly as UTF-8 so it decodes the same way on every
    platform. The default encoding is locale-dependent (e.g. cp1252 on
    Windows), which mis-decodes or rejects non-ASCII CSS.
    """
    with open(file_name, encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
# def set_header():
#     return st.markdown(f"""<div class='main-header'>
#     <h1>H & M Recommendations</h1>
#     <img src="data:image;base64,{data_url}", alt="Logo">
#     </div>""", unsafe_allow_html=True)

# Second set of path constants, carrying the same values as path / DATA_PATH /
# IMAGES_PATH defined earlier in the module.
path1 = os.path.dirname(__file__)
# file_1 = open(f"{path}/willbank.png", "rb")
# contents1 = file_1.read()
# data_url1 = base64.b64encode(contents1).decode("utf-8")
# file_1.close()
DATA_PATH1 = "./data"
IMAGES_PATH1 = f"{DATA_PATH1}/images_224_224"
def set_header():
    """
    Render the app header: the logo image inside a ``main-header`` div.

    The logo is inlined from the module-level base64 string ``data_url``.
    Styling comes from the ``main-header`` / ``blend-logo`` CSS classes.

    Returns whatever ``st.markdown`` returns.
    """
    # Markup fixes:
    # - the outer <div class='main-header'> is now closed;
    # - the stray comma between the <img> attributes is removed;
    # - the data URL now names the PNG subtype of the logo.png it embeds.
    return st.markdown(
        f"""<div class='main-header'>
<!-- <h1></h1> -->
<div>
<img class='blend-logo' src="data:image/png;base64,{data_url}" alt="Logo">
</div>
</div>""",
        unsafe_allow_html=True,
    )
# def set_header():
#     logo_path = "./path/to/your/local/LIME_logo.png"  # Replace with the actual file path
#     text = "LiME"
#     return st.markdown(f"""<div class='main-header'>
#     <img src="data:image/png;base64,{data_url}" alt="Logo" style="float: left; margin-right: 10px; width: 100px; height: auto;">
#     <h1>{text}</h1>
#     </div>""", unsafe_allow_html=True)


def s_curve(x, K, b, a, x0):
    """
    Logistic (S-shaped) response curve.

        f(x) = K / (1 + b * exp(-a * (x - x0)))

    Parameters:
    x: Input value(s). Scalars or NumPy arrays, evaluated element-wise.
    K: Level approached as x grows (for a > 0).
    b: Weight of the exponential term.
    a: Steepness.
    x0: Horizontal offset.
    """
    decay = np.exp(-a * (x - x0))
    return K / (1 + b * decay)
def panel_level(input_df, date_column="Date"):
    """
    Collapse panel rows into one row per date by summing numeric columns.

    Parameters:
    input_df (pd.DataFrame): Data with ``date_column`` either as a column or
        already among the index names.
    date_column (str): Name of the date column/index level to group on.

    Returns:
    pd.DataFrame with the grouping key restored as the first column, followed
    by the summed numeric columns. Non-numeric columns (e.g. panel labels)
    are dropped.
    """
    if date_column in input_df.index.names:
        frame = input_df
    else:
        frame = input_df.set_index(date_column)
    totals = frame.select_dtypes(include="number").groupby(frame.index).sum()
    return totals.reset_index()
def fetch_actual_data(
    panel=None,
    target_file="Overview_data_test.xlsx",
    updated_rcs=None,
    metrics=None,
):
    """
    Read actual channel inputs and modelled contributions from an Excel workbook.

    Parameters:
    panel (str, optional): How to handle panels:
        - a panel name keeps only rows whose "Panel" equals it;
        - "Aggregated" sums all panels per date via panel_level;
        - None uses the sheets as they are.
    target_file (str): Workbook path. It must contain the sheets
        "RAW DATA MMM", "SPEND INPUT" and "CONTRIBUTION MMM".
    updated_rcs, metrics: Unused; kept for signature compatibility.

    Returns:
    (pd.DataFrame, pd.DataFrame): per-channel input values and per-channel
    contributions. Rows are in Date order, columns in workbook order.

    Channels are the raw-data columns minus:
    - Date and the known control / seasonality / promotion columns;
    - "Unnamed..." columns;
    - two hard-coded exclusions;
    - any column whose total contribution is <= 0.
    """
    excel = pd.read_excel(Path(target_file), sheet_name=None)
    # Extract dataframes for raw data, spend input, and contribution MMM
    raw_df = excel["RAW DATA MMM"]
    spend_df = excel["SPEND INPUT"]
    contri_df = excel["CONTRIBUTION MMM"]

    if panel is not None and panel != "Aggregated":
        raw_df = raw_df[raw_df["Panel"] == panel].drop(columns=["Panel"])
        spend_df = spend_df[spend_df["Panel"] == panel].drop(columns=["Panel"])
        contri_df = contri_df[contri_df["Panel"] == panel].drop(columns=["Panel"])
    elif panel == "Aggregated":
        raw_df = panel_level(raw_df, date_column="Date")
        spend_df = panel_level(spend_df, date_column="Week")
        contri_df = panel_level(contri_df, date_column="Date")

    # Seasonalities, indices, promotions and other non-media columns.
    unnamed_cols = [col for col in raw_df.columns if col.lower().startswith("unnamed")]
    exclude_columns = {
        "Date",
        "Region",
        "Controls_Grammarly_Index_SeasonalAVG",
        "Controls_Quillbot_Index",
        "Daily_Positive_Outliers",
        "External_RemoteClass_Index",
        "Intervals ON 20190520-20190805 | 20200518-20200803 | 20210517-20210802",
        "Intervals ON 20190826-20191209 | 20200824-20201207 | 20210823-20211206",
        "Intervals ON 20201005-20201019",
        "Promotion_PercentOff",
        "Promotion_TimeBased",
        "Seasonality_Indicator_Chirstmas",
        "Seasonality_Indicator_NewYears_Days",
        "Seasonality_Indicator_Thanksgiving",
        "Trend 20200302 / 20200803",
        *unnamed_cols,
    }

    raw_df["Date"] = pd.to_datetime(raw_df["Date"])
    contri_df["Date"] = pd.to_datetime(contri_df["Date"])
    input_df = raw_df.sort_values(by="Date")
    output_df = contri_df.sort_values(by="Date")

    # The spend sheet is still parsed and sorted, so a sheet without a "Week"
    # column fails as before. It does not feed the returned frames.
    spend_df["Week"] = pd.to_datetime(
        spend_df["Week"], format="%Y-%m-%d", errors="coerce"
    )
    spend_df.sort_values(by="Week", inplace=True)

    # Channels whose total contribution is non-positive cannot be modelled.
    infeasible_channels = {
        c
        for c in contri_df.select_dtypes(include=["float", "int"]).columns
        if contri_df[c].sum() <= 0
    }
    dropped = (
        exclude_columns
        | {"fb_level_achieved_tier_1", "ga_app"}
        | infeasible_channels
    )
    # Filter in workbook order. The previous set difference made the column
    # order depend on string-hash randomisation, so it changed between runs.
    channel_list = [col for col in input_df.columns if col not in dropped]

    # Contribution columns share the channel's name.
    actual_input_dic = {col: input_df[col].values.copy() for col in channel_list}
    actual_output_dic = {col: output_df[col].values.copy() for col in channel_list}
    return pd.DataFrame(actual_input_dic), pd.DataFrame(actual_output_dic)
| # Function to initialize model results data | |
| def initialize_data(panel=None, metrics=None): | |
| # Extract dataframes for raw data, spend input, and contribution data | |
| raw_df = st.session_state["project_dct"]["current_media_performance"][ | |
| "model_outputs" | |
| ][metrics]["raw_data"].copy() | |
| spend_df = st.session_state["project_dct"]["current_media_performance"][ | |
| "model_outputs" | |
| ][metrics]["spends_data"].copy() | |
| contribution_df = st.session_state["project_dct"]["current_media_performance"][ | |
| "model_outputs" | |
| ][metrics]["contribution_data"].copy() | |
| # Check if 'Panel' or 'panel' is in the columns | |
| panel_column = None | |
| if "Panel" in raw_df.columns: | |
| panel_column = "Panel" | |
| elif "panel" in raw_df.columns: | |
| panel_column = "panel" | |
| # Filter data by panel if provided | |
| if panel and panel.lower() != "aggregated": | |
| raw_df = raw_df[raw_df[panel_column] == panel].drop(columns=[panel_column]) | |
| spend_df = spend_df[spend_df[panel_column] == panel].drop( | |
| columns=[panel_column] | |
| ) | |
| contribution_df = contribution_df[contribution_df[panel_column] == panel].drop( | |
| columns=[panel_column] | |
| ) | |
| else: | |
| raw_df = panel_level(raw_df, date_column="Date") | |
| spend_df = panel_level(spend_df, date_column="Date") | |
| contribution_df = panel_level(contribution_df, date_column="Date") | |
| # Remove unnecessary columns | |
| unnamed_cols = [col for col in raw_df.columns if col.lower().startswith("unnamed")] | |
| exclude_columns = ["Date"] + unnamed_cols | |
| # Convert Date columns to datetime | |
| for df in [raw_df, spend_df, contribution_df]: | |
| df["Date"] = pd.to_datetime(df["Date"], format="%Y-%m-%d", errors="coerce") | |
| # Sort data by Date | |
| input_df = raw_df.sort_values(by="Date") | |
| contribution_df = contribution_df.sort_values(by="Date") | |
| spend_df.sort_values(by="Date", inplace=True) | |
| # Extract channels excluding unwanted columns | |
| channel_list = [col for col in input_df.columns if col not in exclude_columns] | |
| # Filter out channels with non-positive contributions | |
| negative_contributions = [ | |
| col | |
| for col in contribution_df.select_dtypes(include=["float", "int"]).columns | |
| if contribution_df[col].sum() <= 0 | |
| ] | |
| channel_list = list(set(channel_list) - set(negative_contributions)) | |
| # Initialize dictionaries for metrics and response curves | |
| response_curves, mapes, rmses, upper_limits = {}, {}, {}, {} | |
| r2_scores, powers, conversion_rates, actual_output, actual_input = ( | |
| {}, | |
| {}, | |
| {}, | |
| {}, | |
| {}, | |
| ) | |
| channels = {} | |
| sales = None | |
| dates = input_df["Date"].values | |
| # Fit s-curve for each channel | |
| for channel in channel_list: | |
| spends = input_df[channel].values | |
| x = spends.copy() | |
| upper_limits[channel] = 2 * x.max() | |
| # Get corresponding output column | |
| output_col = [ | |
| _col for _col in contribution_df.columns if _col.startswith(channel) | |
| ][0] | |
| y = contribution_df[output_col].values.copy() | |
| actual_output[channel] = y.copy() | |
| actual_input[channel] = x.copy() | |
| # Scale input data | |
| power = np.ceil(np.log(x.max()) / np.log(10)) - 3 | |
| if power >= 0: | |
| x = x / 10**power | |
| x, y = x.astype("float64"), y.astype("float64") | |
| # Set bounds for curve fitting | |
| if y.max() <= 0.01: | |
| bounds = ( | |
| (0, 0, 0, 0), | |
| (3 * 0.01, 1000, 1, x.max() if x.max() > 0 else 0.01), | |
| ) | |
| else: | |
| bounds = ((0, 0, 0, 0), (3 * y.max(), 1000, 1, x.max())) | |
| # Set y to 0 where x is 0 | |
| y[x == 0] = 0 | |
| # Fit s-curve and calculate metrics | |
| # params, _ = curve_fit( | |
| # s_curve, | |
| # x | |
| # y, | |
| # p0=(2 * y.max(), 0.01, 1e-5, x.max()), | |
| # bounds=bounds, | |
| # maxfev=int(1e6), | |
| # ) | |
| params, _ = curve_fit( | |
| s_curve, | |
| list(x) + [0] * len(x), | |
| list(y) + [0] * len(y), | |
| p0=(2 * y.max(), 0.01, 1e-5, x.max()), | |
| bounds=bounds, | |
| maxfev=int(1e6), | |
| ) | |
| mape = (100 * abs(1 - s_curve(x, *params) / y.clip(min=1))).mean() | |
| rmse = np.sqrt(((y - s_curve(x, *params)) ** 2).mean()) | |
| r2_score_ = r2_score(y, s_curve(x, *params)) | |
| # Store metrics and parameters | |
| response_curves[channel] = { | |
| "K": params[0], | |
| "b": params[1], | |
| "a": params[2], | |
| "x0": params[3], | |
| } | |
| mapes[channel] = mape | |
| rmses[channel] = rmse | |
| r2_scores[channel] = r2_score_ | |
| powers[channel] = power | |
| conversion_rate = spend_df[channel].sum() / max(input_df[channel].sum(), 1e-9) | |
| conversion_rates[channel] = conversion_rate | |
| correction = y - s_curve(x, *params) | |
| # Initialize Channel object | |
| channel_obj = Channel( | |
| name=channel, | |
| dates=dates, | |
| spends=spends, | |
| conversion_rate=conversion_rate, | |
| response_curve_type="s-curve", | |
| response_curve_params={ | |
| "K": params[0], | |
| "b": params[1], | |
| "a": params[2], | |
| "x0": params[3], | |
| }, | |
| bounds=np.array([-10, 10]), | |
| correction=correction, | |
| ) | |
| channels[channel] = channel_obj | |
| if sales is None: | |
| sales = channel_obj.actual_sales | |
| else: | |
| sales += channel_obj.actual_sales | |
| # Calculate other contributions | |
| other_contributions = ( | |
| contribution_df.drop(columns=[*response_curves.keys()]) | |
| .sum(axis=1, numeric_only=True) | |
| .values | |
| ) | |
| # Initialize Scenario object | |
| scenario = Scenario( | |
| name="default", | |
| channels=channels, | |
| constant=other_contributions, | |
| correction=np.array([]), | |
| ) | |
| # Set session state variables | |
| st.session_state.update( | |
| { | |
| "initialized": True, | |
| "actual_df": input_df, | |
| "raw_df": raw_df, | |
| "contri_df": contribution_df, | |
| "default_scenario_dict": class_to_dict(scenario), | |
| "scenario": scenario, | |
| "channels_list": channel_list, | |
| "optimization_channels": { | |
| channel_name: False for channel_name in channel_list | |
| }, | |
| "rcs": response_curves.copy(), | |
| "powers": powers, | |
| "actual_contribution_df": pd.DataFrame(actual_output), | |
| "actual_input_df": pd.DataFrame(actual_input), | |
| "xlsx_buffer": io.BytesIO(), | |
| "saved_scenarios": ( | |
| pickle.load(open("../saved_scenarios.pkl", "rb")) | |
| if Path("../saved_scenarios.pkl").exists() | |
| else OrderedDict() | |
| ), | |
| "disable_download_button": True, | |
| } | |
| ) | |
| for channel in channels.values(): | |
| st.session_state[channel.name] = numerize( | |
| channel.actual_total_spends * channel.conversion_rate, 1 | |
| ) | |
| # Prepare response curve data for output | |
| response_curve_data = {} | |
| for channel, params in st.session_state["rcs"].items(): | |
| x = st.session_state["actual_input_df"][channel].values.astype(float) | |
| y = st.session_state["actual_contribution_df"][channel].values.astype(float) | |
| power = float(np.ceil(np.log(max(x)) / np.log(10)) - 3) | |
| x_plot = list(np.linspace(0, 5 * max(x), 100)) | |
| response_curve_data[channel] = { | |
| "K": float(params["K"]), | |
| "b": float(params["b"]), | |
| "a": float(params["a"]), | |
| "x0": float(params["x0"]), | |
| "power": power, | |
| "x": list(x), | |
| "y": list(y), | |
| "x_plot": x_plot, | |
| } | |
| return response_curve_data, scenario | |
| # def initialize_data(panel=None, metrics=None): | |
| # # Extract dataframes for raw data, spend input, and contribution data | |
| # raw_df = st.session_state["project_dct"]["current_media_performance"][ | |
| # "model_outputs" | |
| # ][metrics]["raw_data"] | |
| # spend_df = st.session_state["project_dct"]["current_media_performance"][ | |
| # "model_outputs" | |
| # ][metrics]["spends_data"] | |
| # contri_df = st.session_state["project_dct"]["current_media_performance"][ | |
| # "model_outputs" | |
| # ][metrics]["contribution_data"] | |
| # # Check if the panel is not None | |
| # if panel is not None and panel.lower() != "aggregated": | |
| # raw_df = raw_df[raw_df["Panel"] == panel].drop(columns=["Panel"]) | |
| # spend_df = spend_df[spend_df["Panel"] == panel].drop(columns=["Panel"]) | |
| # contri_df = contri_df[contri_df["Panel"] == panel].drop(columns=["Panel"]) | |
| # elif panel.lower() == "aggregated": | |
| # raw_df = panel_level(raw_df, date_column="Date") | |
| # spend_df = panel_level(spend_df, date_column="Date") | |
| # contri_df = panel_level(contri_df, date_column="Date") | |
| # ## remove sesonalities, indices etc ... | |
| # unnamed_cols = [col for col in raw_df.columns if col.lower().startswith("unnamed")] | |
| # ## remove sesonalities, indices etc ... | |
| # exclude_columns = ["Date"] + unnamed_cols | |
| # raw_df["Date"] = pd.to_datetime(raw_df["Date"], format="%Y-%m-%d", errors="coerce") | |
| # contri_df["Date"] = pd.to_datetime( | |
| # contri_df["Date"], format="%Y-%m-%d", errors="coerce" | |
| # ) | |
| # spend_df["Date"] = pd.to_datetime( | |
| # spend_df["Date"], format="%Y-%m-%d", errors="coerce" | |
| # ) | |
| # input_df = raw_df.sort_values(by="Date") | |
| # output_df = contri_df.sort_values(by="Date") | |
| # spend_df.sort_values(by="Date", inplace=True) | |
| # channel_list = [col for col in input_df.columns if col not in exclude_columns] | |
| # negative_contribution = [ | |
| # c | |
| # for c in contri_df.select_dtypes(include=["float", "int"]).columns | |
| # if contri_df[c].sum() <= 0 | |
| # ] | |
| # channel_list = list(set(channel_list) - set(negative_contribution)) | |
| # response_curves = {} | |
| # mapes = {} | |
| # rmses = {} | |
| # upper_limits = {} | |
| # powers = {} | |
| # r2 = {} | |
| # conv_rates = {} | |
| # output_cols = [] | |
| # channels = {} | |
| # sales = None | |
| # dates = input_df.Date.values | |
| # actual_output_dic = {} | |
| # actual_input_dic = {} | |
| # for inp_col in channel_list: | |
| # spends = input_df[inp_col].values | |
| # x = spends.copy() | |
| # # upper limit for penalty | |
| # upper_limits[inp_col] = 2 * x.max() | |
| # # contribution | |
| # out_col = [_col for _col in output_df.columns if _col.startswith(inp_col)][0] | |
| # y = output_df[out_col].values.copy() | |
| # actual_output_dic[inp_col] = y.copy() | |
| # actual_input_dic[inp_col] = x.copy() | |
| # ##output cols aggregation | |
| # output_cols.append(out_col) | |
| # ## scale the input | |
| # power = np.ceil(np.log(x.max()) / np.log(10)) - 3 | |
| # if power >= 0: | |
| # x = x / 10**power | |
| # x = x.astype("float64") | |
| # y = y.astype("float64") | |
| # if y.max() <= 0.01: | |
| # if x.max() <= 0.0: | |
| # bounds = ((0, 0, 0, 0), (3 * 0.01, 1000, 1, 0.01)) | |
| # else: | |
| # bounds = ((0, 0, 0, 0), (3 * 0.01, 1000, 1, x.max())) | |
| # else: | |
| # bounds = ((0, 0, 0, 0), (3 * y.max(), 1000, 1, x.max())) | |
| # params, _ = curve_fit( | |
| # s_curve, | |
| # x, | |
| # y, | |
| # p0=(2 * y.max(), 0.01, 1e-5, x.max()), | |
| # bounds=bounds, | |
| # maxfev=int(1e5), | |
| # ) | |
| # mape = (100 * abs(1 - s_curve(x, *params) / y.clip(min=1))).mean() | |
| # rmse = np.sqrt(((y - s_curve(x, *params)) ** 2).mean()) | |
| # r2_ = r2_score(y, s_curve(x, *params)) | |
| # response_curves[inp_col] = { | |
| # "K": params[0], | |
| # "b": params[1], | |
| # "a": params[2], | |
| # "x0": params[3], | |
| # } | |
| # mapes[inp_col] = mape | |
| # rmses[inp_col] = rmse | |
| # r2[inp_col] = r2_ | |
| # powers[inp_col] = power | |
| # conv = spend_df[inp_col].sum() / max(input_df[inp_col].sum(), 1e-9) | |
| # conv_rates[inp_col] = conv | |
| # correction = y - s_curve(x, *params) | |
| # channel = Channel( | |
| # name=inp_col, | |
| # dates=dates, | |
| # spends=spends, | |
| # conversion_rate=conv_rates[inp_col], | |
| # response_curve_type="s-curve", | |
| # response_curve_params={ | |
| # "K": params[0], | |
| # "b": params[1], | |
| # "a": params[2], | |
| # "x0": params[3], | |
| # }, | |
| # bounds=np.array([-10, 10]), | |
| # correction=correction, | |
| # ) | |
| # channels[inp_col] = channel | |
| # if sales is None: | |
| # sales = channel.actual_sales | |
| # else: | |
| # sales += channel.actual_sales | |
| # other_contributions = ( | |
| # output_df.drop([*output_cols], axis=1).sum(axis=1, numeric_only=True).values | |
| # ) | |
| # scenario = Scenario( | |
| # name="default", | |
| # channels=channels, | |
| # constant=other_contributions, | |
| # correction=np.array([]), | |
| # ) | |
| # ## setting session variables | |
| # st.session_state["initialized"] = True | |
| # st.session_state["actual_df"] = input_df | |
| # st.session_state["raw_df"] = raw_df | |
| # st.session_state["contri_df"] = output_df | |
| # default_scenario_dict = class_to_dict(scenario) | |
| # st.session_state["default_scenario_dict"] = default_scenario_dict | |
| # st.session_state["scenario"] = scenario | |
| # st.session_state["channels_list"] = channel_list | |
| # st.session_state["optimization_channels"] = { | |
| # channel_name: False for channel_name in channel_list | |
| # } | |
| # st.session_state["rcs"] = response_curves.copy() | |
| # st.session_state["powers"] = powers | |
| # st.session_state["actual_contribution_df"] = pd.DataFrame(actual_output_dic) | |
| # st.session_state["actual_input_df"] = pd.DataFrame(actual_input_dic) | |
| # for channel in channels.values(): | |
| # st.session_state[channel.name] = numerize( | |
| # channel.actual_total_spends * channel.conversion_rate, 1 | |
| # ) | |
| # st.session_state["xlsx_buffer"] = io.BytesIO() | |
| # if Path("../saved_scenarios.pkl").exists(): | |
| # with open("../saved_scenarios.pkl", "rb") as f: | |
| # st.session_state["saved_scenarios"] = pickle.load(f) | |
| # else: | |
| # st.session_state["saved_scenarios"] = OrderedDict() | |
| # # st.session_state["total_spends_change"] = 0 | |
| # st.session_state["optimization_channels"] = { | |
| # channel_name: False for channel_name in channel_list | |
| # } | |
| # st.session_state["disable_download_button"] = True | |
| # rcs_data = {} | |
| # for channel in st.session_state["rcs"]: | |
| # # Convert to native Python lists and types | |
| # x = list(st.session_state["actual_input_df"][channel].values.astype(float)) | |
| # y = list( | |
| # st.session_state["actual_contribution_df"][channel].values.astype(float) | |
| # ) | |
| # power = float(np.ceil(np.log(max(x)) / np.log(10)) - 3) | |
| # x_plot = list(np.linspace(0, 5 * max(x), 100)) | |
| # rcs_data[channel] = { | |
| # "K": float(st.session_state["rcs"][channel]["K"]), | |
| # "b": float(st.session_state["rcs"][channel]["b"]), | |
| # "a": float(st.session_state["rcs"][channel]["a"]), | |
| # "x0": float(st.session_state["rcs"][channel]["x0"]), | |
| # "power": power, | |
| # "x": x, | |
| # "y": y, | |
| # "x_plot": x_plot, | |
| # } | |
| # return rcs_data, scenario | |
| # def initialize_data(): | |
| # # fetch data from excel | |
| # output = pd.read_excel('data.xlsx',sheet_name=None) | |
| # raw_df = output['RAW DATA MMM'] | |
| # contribution_df = output['CONTRIBUTION MMM'] | |
| # Revenue_df = output['Revenue'] | |
| # ## channels to be shows | |
| # channel_list = [] | |
| # for col in raw_df.columns: | |
| # if 'click' in col.lower() or 'spend' in col.lower() or 'imp' in col.lower(): | |
| # channel_list.append(col) | |
| # else: | |
| # pass | |
| # ## NOTE : Considered only Desktop spends for all calculations | |
| # acutal_df = raw_df[raw_df.Region == 'Desktop'].copy() | |
| # ## NOTE : Considered one year of data | |
| # acutal_df = acutal_df[acutal_df.Date>'2020-12-31'] | |
| # actual_df = acutal_df.drop('Region',axis=1).sort_values(by='Date')[[*channel_list,'Date']] | |
| # ##load response curves | |
| # with open('./grammarly_response_curves.json','r') as f: | |
| # response_curves = json.load(f) | |
| # ## create channel dict for scenario creation | |
| # dates = actual_df.Date.values | |
| # channels = {} | |
| # rcs = {} | |
| # constant = 0. | |
| # for i,info_dict in enumerate(response_curves): | |
| # name = info_dict.get('name') | |
| # response_curve_type = info_dict.get('response_curve') | |
| # response_curve_params = info_dict.get('params') | |
| # rcs[name] = response_curve_params | |
| # if name != 'constant': | |
| # spends = actual_df[name].values | |
| # channel = Channel(name=name,dates=dates, | |
| # spends=spends, | |
| # response_curve_type=response_curve_type, | |
| # response_curve_params=response_curve_params, | |
| # bounds=np.array([-30,30])) | |
| # channels[name] = channel | |
| # else: | |
| # constant = info_dict.get('value',0.) * len(dates) | |
| # ## create scenario | |
| # scenario = Scenario(name='default', channels=channels, constant=constant) | |
| # default_scenario_dict = class_to_dict(scenario) | |
| # ## setting session variables | |
| # st.session_state['initialized'] = True | |
| # st.session_state['actual_df'] = actual_df | |
| # st.session_state['raw_df'] = raw_df | |
| # st.session_state['default_scenario_dict'] = default_scenario_dict | |
| # st.session_state['scenario'] = scenario | |
| # st.session_state['channels_list'] = channel_list | |
| # st.session_state['optimization_channels'] = {channel_name : False for channel_name in channel_list} | |
| # st.session_state['rcs'] = rcs | |
| # for channel in channels.values(): | |
| # if channel.name not in st.session_state: | |
| # st.session_state[channel.name] = float(channel.actual_total_spends) | |
| # if 'xlsx_buffer' not in st.session_state: | |
| # st.session_state['xlsx_buffer'] = io.BytesIO() | |
| # ## for saving scenarios | |
| # if 'saved_scenarios' not in st.session_state: | |
| # if Path('../saved_scenarios.pkl').exists(): | |
| # with open('../saved_scenarios.pkl','rb') as f: | |
| # st.session_state['saved_scenarios'] = pickle.load(f) | |
| # else: | |
| # st.session_state['saved_scenarios'] = OrderedDict() | |
| # if 'total_spends_change' not in st.session_state: | |
| # st.session_state['total_spends_change'] = 0 | |
| # if 'optimization_channels' not in st.session_state: | |
| # st.session_state['optimization_channels'] = {channel_name : False for channel_name in channel_list} | |
| # if 'disable_download_button' not in st.session_state: | |
| # st.session_state['disable_download_button'] = True | |
def _parse_abbreviated_amount(text):
    """Convert an abbreviated amount such as "$ 11.3K" or "3.5M" to an int.

    A "$" sign and surrounding whitespace are ignored; a trailing "K" or "M"
    scales the number by 1e3 or 1e6. The result is truncated toward zero,
    like ``astype(int)`` on a float column.
    """
    multipliers = {"K": 1e3, "M": 1e6}
    cleaned = text.replace("$", "").strip()
    suffix = cleaned[-1:]
    if suffix in multipliers:
        return int(float(cleaned[:-1]) * multipliers[suffix])
    return int(float(cleaned))


def create_channel_summary(scenario):
    """Return the per-channel summary table (spend, revenue, ROI) for the UI.

    NOTE(review): every figure below is hard-coded; ``scenario`` is accepted
    for interface compatibility but is not read.

    Returns:
        pandas.DataFrame with string columns "Channel", "Spends", "Revenue"
        and "ROI". ROI is ``(revenue - spends) / spends`` formatted with one
        decimal place.
    """
    channels = [
        "Paid Search",
        "Ga will cid baixo risco",
        "Digital tactic others",
        "Fb la tier 1",
        "Fb la tier 2",
        "Paid social others",
        "Programmatic",
        "Kwai",
        "Indicacao",
        "Infleux",
        "Influencer",
    ]
    # Spend strings: shown as-is and also parsed for the ROI calculation.
    spends_display = [
        "$ 11.3K",
        "$ 155.2K",
        "$ 50.7K",
        "$ 125.4K",
        "$ 125.2K",
        "$ 105K",
        "$ 3.3M",
        "$ 47.5K",
        "$ 55.9K",
        "$ 632.3K",
        "$ 48.3K",
    ]
    # Revenue figures used only to compute ROI.
    revenue_for_roi = [
        "558.0K",
        "3.5M",
        "5.2M",
        "3.1M",
        "3.1M",
        "2.1M",
        "20.8M",
        "1.6M",
        "728.4K",
        "22.9M",
        "4.8M",
    ]
    # Revenue strings shown in the table.
    # NOTE(review): these do not match revenue_for_roi (e.g. "Indicacao" is
    # 728.4K above but "$ 7.1M" here), so the displayed ROI cannot be
    # reproduced from the displayed columns. Confirm which set is correct.
    revenue_display = [
        "$ 536.3K",
        "$ 3.4M",
        "$ 5M",
        "$ 3M",
        "$ 3M",
        "$ 2M",
        "$ 20M",
        "$ 1.5M",
        "$ 7.1M",
        "$ 22M",
        "$ 4.6M",
    ]

    spends = [_parse_abbreviated_amount(v) for v in spends_display]
    revenue = [_parse_abbreviated_amount(v) for v in revenue_for_roi]
    roi = [(rev - spend) / spend for rev, spend in zip(revenue, spends)]

    return pd.DataFrame(
        {
            "Channel": channels,
            "Spends": spends_display,
            "Revenue": revenue_display,
            "ROI": [f"{value:.1f}" for value in roi],
        }
    )
| # @st.cache(allow_output_mutation=True) | |
| # def create_contribution_pie(scenario): | |
| # #c1f7dc | |
| # colors_map = {col:color for col,color in zip(st.session_state['channels_list'],plotly.colors.n_colors(plotly.colors.hex_to_rgb('#BE6468'), plotly.colors.hex_to_rgb('#E7B8B7'),23))} | |
| # total_contribution_fig = make_subplots(rows=1, cols=2,subplot_titles=['Spends','Revenue'],specs=[[{"type": "pie"}, {"type": "pie"}]]) | |
| # total_contribution_fig.add_trace( | |
| # go.Pie(labels=[channel_name_formating(channel_name) for channel_name in st.session_state['channels_list']] + ['Non Media'], | |
| # values= [round(scenario.channels[channel_name].actual_total_spends * scenario.channels[channel_name].conversion_rate,1) for channel_name in st.session_state['channels_list']] + [0], | |
| # marker=dict(colors = [plotly.colors.label_rgb(colors_map[channel_name]) for channel_name in st.session_state['channels_list']] + ['#F0F0F0']), | |
| # hole=0.3), | |
| # row=1, col=1) | |
| # total_contribution_fig.add_trace( | |
| # go.Pie(labels=[channel_name_formating(channel_name) for channel_name in st.session_state['channels_list']] + ['Non Media'], | |
| # values= [scenario.channels[channel_name].actual_total_sales for channel_name in st.session_state['channels_list']] + [scenario.correction.sum() + scenario.constant.sum()], | |
| # hole=0.3), | |
| # row=1, col=2) | |
| # total_contribution_fig.update_traces(textposition='inside',texttemplate='%{percent:.1%}') | |
| # total_contribution_fig.update_layout(uniformtext_minsize=12,title='Channel contribution', uniformtext_mode='hide') | |
| # return total_contribution_fig | |
| # @st.cache(allow_output_mutation=True) | |
| # def create_contribuion_stacked_plot(scenario): | |
| # weekly_contribution_fig = make_subplots(rows=1, cols=2,subplot_titles=['Spends','Revenue'],specs=[[{"type": "bar"}, {"type": "bar"}]]) | |
| # raw_df = st.session_state['raw_df'] | |
| # df = raw_df.sort_values(by='Date') | |
| # x = df.Date | |
| # weekly_spends_data = [] | |
| # weekly_sales_data = [] | |
| # for channel_name in st.session_state['channels_list']: | |
| # weekly_spends_data.append((go.Bar(x=x, | |
| # y=scenario.channels[channel_name].actual_spends * scenario.channels[channel_name].conversion_rate, | |
| # name=channel_name_formating(channel_name), | |
| # hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}", | |
| # legendgroup=channel_name))) | |
| # weekly_sales_data.append((go.Bar(x=x, | |
| # y=scenario.channels[channel_name].actual_sales, | |
| # name=channel_name_formating(channel_name), | |
| # hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}", | |
| # legendgroup=channel_name, showlegend=False))) | |
| # for _d in weekly_spends_data: | |
| # weekly_contribution_fig.add_trace(_d, row=1, col=1) | |
| # for _d in weekly_sales_data: | |
| # weekly_contribution_fig.add_trace(_d, row=1, col=2) | |
| # weekly_contribution_fig.add_trace(go.Bar(x=x, | |
| # y=scenario.constant + scenario.correction, | |
| # name='Non Media', | |
| # hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), row=1, col=2) | |
| # weekly_contribution_fig.update_layout(barmode='stack', title='Channel contribuion by week', xaxis_title='Date') | |
| # weekly_contribution_fig.update_xaxes(showgrid=False) | |
| # weekly_contribution_fig.update_yaxes(showgrid=False) | |
| # return weekly_contribution_fig | |
| # @st.cache(allow_output_mutation=True) | |
| # def create_channel_spends_sales_plot(channel): | |
| # if channel is not None: | |
| # x = channel.dates | |
| # _spends = channel.actual_spends * channel.conversion_rate | |
| # _sales = channel.actual_sales | |
| # channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]]) | |
| # channel_sales_spends_fig.add_trace(go.Bar(x=x, y=_sales,marker_color='#c1f7dc',name='Revenue', hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), secondary_y = False) | |
| # channel_sales_spends_fig.add_trace(go.Scatter(x=x, y=_spends,line=dict(color='#005b96'),name='Spends',hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}"), secondary_y = True) | |
| # channel_sales_spends_fig.update_layout(xaxis_title='Date',yaxis_title='Revenue',yaxis2_title='Spends ($)',title='Channel spends and Revenue week wise') | |
| # channel_sales_spends_fig.update_xaxes(showgrid=False) | |
| # channel_sales_spends_fig.update_yaxes(showgrid=False) | |
| # else: | |
| # raw_df = st.session_state['raw_df'] | |
| # df = raw_df.sort_values(by='Date') | |
| # x = df.Date | |
| # scenario = class_from_dict(st.session_state['default_scenario_dict']) | |
| # _sales = scenario.constant + scenario.correction | |
| # channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]]) | |
| # channel_sales_spends_fig.add_trace(go.Bar(x=x, y=_sales,marker_color='#c1f7dc',name='Revenue', hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), secondary_y = False) | |
| # # channel_sales_spends_fig.add_trace(go.Scatter(x=x, y=_spends,line=dict(color='#15C39A'),name='Spends',hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}"), secondary_y = True) | |
| # channel_sales_spends_fig.update_layout(xaxis_title='Date',yaxis_title='Revenue',yaxis2_title='Spends ($)',title='Channel spends and Revenue week wise') | |
| # channel_sales_spends_fig.update_xaxes(showgrid=False) | |
| # channel_sales_spends_fig.update_yaxes(showgrid=False) | |
| # return channel_sales_spends_fig | |
# Figure builders. create_contribution_pie uses its own local palette; the other plots use the module-level color_palette.
def create_contribution_pie():
    """Build the "Channel contribution" figure: spend-share and revenue-share pies.

    NOTE(review): channel names and slice values are hard-coded; nothing is
    read from session state or a scenario.

    Returns:
        A plotly figure with a "Spends" donut in column 1 and a "Revenue"
        donut in column 2, each slice labelled with its percentage.
    """
    color_palette = [
        "#F3F3F0",
        "#5E7D7E",
        "#2FA1FF",
        "#00EDED",
        "#00EAE4",
        "#304550",
        "#EDEBEB",
        "#7FBEFD",
        "#003059",
        "#A2F3F3",
        "#E1D6E2",
        "#B6B6B6",
    ]
    channels_list = [
        "Paid Search",
        "Ga will cid baixo risco",
        "Digital tactic others",
        "Fb la tier 1",
        "Fb la tier 2",
        "Paid social others",
        "Programmatic",
        "Kwai",
        "Indicacao",
        "Infleux",
        "Influencer",
        "Non Media",
    ]
    # One palette colour per label, cycling if labels outnumber colours.
    colors_map = {
        name: color_palette[i % len(color_palette)]
        for i, name in enumerate(channels_list)
    }
    # "Non Media" is pinned to the last palette colour, matching the
    # color_palette[-1] used for it in create_contribuion_stacked_plot.
    # color_palette[5] would reuse the "Paid social others" colour and make
    # those two slices indistinguishable.
    colors_map["Non Media"] = color_palette[-1]
    slice_colors = [colors_map[name] for name in channels_list]

    # Hard-coded slice sizes in channels_list order; plotly converts them to
    # percentages for display.
    spends_values = [0.5, 3.36, 1.1, 2.7, 2.7, 2.27, 70.6, 1, 1, 13.7, 1, 0]
    revenue_values = [1, 4, 5, 3, 3, 2, 50.8, 1.5, 0.7, 13, 0, 16]

    total_contribution_fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=["Spends", "Revenue"],
        specs=[[{"type": "pie"}, {"type": "pie"}]],
    )
    # Column 1: spends pie; column 2: revenue pie.
    for col, values in ((1, spends_values), (2, revenue_values)):
        total_contribution_fig.add_trace(
            go.Pie(
                labels=list(channels_list),
                values=values,
                marker=dict(colors=list(slice_colors)),
                hole=0.3,
            ),
            row=1,
            col=col,
        )
    total_contribution_fig.update_traces(
        textposition="inside", texttemplate="%{percent:.1%}"
    )
    total_contribution_fig.update_layout(
        uniformtext_minsize=12,
        title="Channel contribution",
        uniformtext_mode="hide",
    )
    return total_contribution_fig
def create_contribuion_stacked_plot(scenario):
    """Build the weekly stacked-bar figure: spends on the left, revenue on the right.

    For every name in ``st.session_state["channels_list"]`` one spend trace
    (``actual_spends * conversion_rate``) is added to column 1 and one revenue
    trace (``actual_sales``) to column 2, sharing a legend group and colour.
    A final "Non Media" revenue trace (``scenario.constant +
    scenario.correction``) is stacked in column 2. The x values are the
    "Date" column of ``st.session_state["raw_df"]`` after sorting by date.
    """
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=["Spends", "Revenue"],
        specs=[[{"type": "bar"}, {"type": "bar"}]],
    )
    week_dates = st.session_state["raw_df"].sort_values(by="Date").Date

    spend_traces = []
    revenue_traces = []
    palette_size = len(color_palette)
    for idx, name in enumerate(st.session_state["channels_list"]):
        bar_color = color_palette[idx % palette_size]
        ch = scenario.channels[name]
        label = channel_name_formating(name)
        spend_traces.append(
            go.Bar(
                x=week_dates,
                y=ch.actual_spends * ch.conversion_rate,
                name=label,
                hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
                legendgroup=name,
                marker_color=bar_color,
            )
        )
        # Revenue bars reuse the spend legend entry via legendgroup.
        revenue_traces.append(
            go.Bar(
                x=week_dates,
                y=ch.actual_sales,
                name=label,
                hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
                legendgroup=name,
                showlegend=False,
                marker_color=bar_color,
            )
        )

    # Spend traces are added first, then revenue traces, so trace order is stable.
    for trace in spend_traces:
        fig.add_trace(trace, row=1, col=1)
    for trace in revenue_traces:
        fig.add_trace(trace, row=1, col=2)

    non_media_trace = go.Bar(
        x=week_dates,
        y=scenario.constant + scenario.correction,
        name="Non Media",
        hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
        marker_color=color_palette[-1],
    )
    fig.add_trace(non_media_trace, row=1, col=2)

    fig.update_layout(
        barmode="stack",
        title="Channel contribution by week",
        xaxis_title="Date",
    )
    fig.update_xaxes(showgrid=False)
    fig.update_yaxes(showgrid=False)
    return fig
def create_channel_spends_sales_plot(channel):
    """Build a week-wise figure of revenue bars, plus a spend line for a channel.

    Args:
        channel: a channel object, or None.
            - With a channel, the bars are ``channel.actual_sales`` and a
              secondary-axis line shows ``channel.actual_spends *
              channel.conversion_rate``, both plotted against ``channel.dates``.
            - With None, only bars are drawn. They show ``constant + correction``
              of the scenario rebuilt from
              ``st.session_state["default_scenario_dict"]``, plotted against
              the sorted "Date" column of ``st.session_state["raw_df"]``.

    Returns:
        A plotly figure; titles and axis labels are the same in both cases.
    """
    fig = make_subplots(specs=[[{"secondary_y": True}]])

    if channel is None:
        dates = st.session_state["raw_df"].sort_values(by="Date").Date
        default_scenario = class_from_dict(st.session_state["default_scenario_dict"])
        revenue = default_scenario.constant + default_scenario.correction
        bar_color = color_palette[0]
    else:
        dates = channel.dates
        revenue = channel.actual_sales
        bar_color = color_palette[3]

    fig.add_trace(
        go.Bar(
            x=dates,
            y=revenue,
            marker_color=bar_color,
            name="Revenue",
            hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
        ),
        secondary_y=False,
    )

    # The spend line only exists for a concrete channel.
    if channel is not None:
        fig.add_trace(
            go.Scatter(
                x=dates,
                y=channel.actual_spends * channel.conversion_rate,
                line=dict(color=color_palette[2]),
                name="Spends",
                hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
            ),
            secondary_y=True,
        )

    fig.update_layout(
        xaxis_title="Date",
        yaxis_title="Revenue",
        yaxis2_title="Spends ($)",
        title="Channel spends and Revenue week-wise",
    )
    fig.update_xaxes(showgrid=False)
    fig.update_yaxes(showgrid=False)
    return fig
def format_numbers(value, n_decimals=1, include_indicator=True):
    """Format *value* for display.

    ``None`` is passed through unchanged. Values below 1 (including
    negatives) are shown as-is; anything else is abbreviated with
    ``numerize(value, n_decimals)``. When *include_indicator* is true, the
    text is prefixed with ``CURRENCY_INDICATOR`` and a space.
    """
    if value is None:
        return None
    if value < 1:
        shown = value
    else:
        shown = numerize(value, n_decimals)
    return f"{CURRENCY_INDICATOR} {shown}" if include_indicator else f"{shown}"
def decimal_formater(num_string, n_decimals=1):
    """Right-pad the fractional part of *num_string* with zeros to *n_decimals* digits.

    Never truncates: if the text after the last "." already has at least
    *n_decimals* characters, *num_string* is returned unchanged. A string
    without a "." gets ".000..." appended. When *n_decimals* <= 0 it is
    returned unchanged, rather than with a dangling "." (e.g. "5.").

    Examples: "5" -> "5.0", ("5.1", 3) -> "5.100", ("5", 0) -> "5".
    """
    if "." not in num_string:
        if n_decimals <= 0:
            return num_string
        return num_string + "." + "0" * n_decimals
    fraction = num_string.rsplit(".", 1)[1]
    missing = n_decimals - len(fraction)
    return num_string + "0" * max(missing, 0)
def channel_name_formating(channel_name):
    """Turn a raw channel column name into a display label.

    Underscores become spaces. If the name then ends with the word "imp" or
    "clicks" (matched case-insensitively), that final word is replaced by
    "Spend". All other text is left as-is.

    Examples:
        "Fb_Imp" -> "Fb Spend"
        "paid_search_clicks" -> "paid search Spend"
        "Impact_Imp" -> "Impact Spend"
    """
    name_mod = channel_name.replace("_", " ")
    lowered = name_mod.lower()
    for suffix in (" imp", " clicks"):
        if lowered.endswith(suffix):
            # Replace only the matched suffix: a case-sensitive str.replace
            # misses "x_imp" and also rewrites any earlier "Imp"/"Clicks".
            return name_mod[: -len(suffix)] + " Spend"
    return name_mod
def send_email(email, message, *, sender=None, password=None):
    """Send *message* to *email* via Gmail's SMTP server (STARTTLS on port 587).

    Credentials come from the keyword arguments, falling back to the
    ``SMTP_SENDER`` and ``SMTP_PASSWORD`` environment variables.
    SECURITY: credentials must not live in source. The password that was
    previously hard-coded here has been published and must be revoked.

    Raises:
        RuntimeError: if no sender address or password is available.
        smtplib.SMTPException / OSError: propagated from the SMTP session.
    """
    sender = sender or os.environ.get("SMTP_SENDER")
    password = password or os.environ.get("SMTP_PASSWORD")
    if not sender or not password:
        raise RuntimeError(
            "SMTP credentials missing: set SMTP_SENDER and SMTP_PASSWORD"
        )
    # The context manager sends QUIT and closes the socket even when login
    # or sendmail raises, so the connection is never leaked.
    with smtplib.SMTP("smtp.gmail.com", 587) as server:
        server.starttls()
        server.login(sender, password)
        server.sendmail(sender, email, message)
| # if __name__ == "__main__": | |
| # initialize_data() | |
| ############################################################################################################# | |
| import os | |
| import json | |
| import streamlit as st | |
| # Function to get panels names | |
def get_panels_names(file_selected):
    """Return the panel names for one model output, followed by "aggregated".

    Reads ``st.session_state["project_dct"]["current_media_performance"]
    ["model_outputs"][file_selected]["raw_data"]`` and collects the distinct
    values of its "panel" column (or, failing that, "Panel"). If neither
    column exists, only ``["aggregated"]`` is returned.

    Panels are listed in first-appearance order. The order is deterministic,
    unlike iterating a ``set``, whose order for strings changes between
    interpreter runs.
    """
    raw_data_df = st.session_state["project_dct"]["current_media_performance"][
        "model_outputs"
    ][file_selected]["raw_data"]
    panels = []
    for column in ("panel", "Panel"):
        if column in raw_data_df.columns:
            # dict.fromkeys de-duplicates while keeping first-seen order.
            panels = list(dict.fromkeys(raw_data_df[column]))
            break
    return panels + ["aggregated"]
| # Function to get metrics names | |
def get_metrics_names():
    """Return the metric names (the keys of the project's model outputs) as a list.

    The keys are read from
    ``st.session_state["project_dct"]["current_media_performance"]["model_outputs"]``.
    """
    project = st.session_state["project_dct"]
    model_outputs = project["current_media_performance"]["model_outputs"]
    return list(model_outputs.keys())
| # Function to load the original and modified rcs metadata files into dictionaries | |
def load_rcs_metadata_files():
    """Return ``(original, modified)`` response-curve metadata from the project dict.

    Despite the name, nothing is read from disk. Both values are taken from
    ``st.session_state["project_dct"]["response_curves"]`` under the keys
    "original_metadata_file" and "modified_metadata_file".
    """
    response_curves = st.session_state["project_dct"]["response_curves"]
    original_data = response_curves["original_metadata_file"]
    modified_data = response_curves["modified_metadata_file"]
    return original_data, modified_data
| # Function to format name | |
def name_formating(name):
    """Convert a snake_case name to a title-cased label, e.g. "paid_search" -> "Paid Search"."""
    return " ".join(name.split("_")).title()
| # Function to load the original and modified scenario metadata files into dictionaries | |
def load_scenario_metadata_files():
    """Return ``(original, modified)`` scenario-planner metadata from the project dict.

    Nothing is read from disk. Both values come from
    ``st.session_state["project_dct"]["scenario_planner"]`` under the keys
    "original_metadata_file" and "modified_metadata_file".
    """
    planner = st.session_state["project_dct"]["scenario_planner"]
    return planner["original_metadata_file"], planner["modified_metadata_file"]
| # Function to generate RCS data and store it as dictionary | |
def generate_rcs_data():
    """Build original and modified response-curve (RCS) data for every metric and panel.

    For each metric from ``get_metrics_names()`` and each panel from
    ``get_panels_names()``, ``initialize_data`` supplies the original per-channel
    RCS dictionaries. The modified structure reuses any previously saved entry
    from ``project_dct["response_curves"]["modified_metadata_file"]`` for that
    metric/panel/channel; otherwise it starts from the original ``"K"``,
    ``"b"``, ``"a"`` and ``"x0"`` values.

    Both results are written back to
    ``st.session_state["project_dct"]["response_curves"]`` under
    ``"original_metadata_file"`` and ``"modified_metadata_file"``.
    """
    # Previously saved edits. Default to an empty dict so lookups below fall
    # back cleanly instead of relying on an unbound name being swallowed.
    modified_data = st.session_state["project_dct"]["response_curves"][
        "modified_metadata_file"
    ]
    if modified_data is None:
        modified_data = {}

    all_rcs_data_original = {}
    all_rcs_data_modified = {}

    for metric in get_metrics_names():
        original_by_panel = all_rcs_data_original.setdefault(metric, {})
        modified_by_panel = all_rcs_data_modified.setdefault(metric, {})

        for panel in get_panels_names(file_selected=metric):
            rcs_dict_original, _ = initialize_data(
                panel=panel,
                metrics=metric,
            )
            original_by_panel[panel] = rcs_dict_original

            # Always create the panel entry so the modified structure mirrors
            # the original one even when a panel has no channels.
            modified_channels = modified_by_panel.setdefault(panel, {})

            for channel, original_params in rcs_dict_original.items():
                try:
                    updated_rcs_dict = modified_data[metric][panel][channel]
                except (KeyError, TypeError):
                    # No saved edit for this channel: start from the original fit.
                    updated_rcs_dict = {
                        key: original_params[key] for key in ("K", "b", "a", "x0")
                    }
                modified_channels[channel] = updated_rcs_dict

    # Write the original RCS data
    st.session_state["project_dct"]["response_curves"][
        "original_metadata_file"
    ] = all_rcs_data_original
    # Write the modified RCS data
    st.session_state["project_dct"]["response_curves"][
        "modified_metadata_file"
    ] = all_rcs_data_modified
| # Function to generate scenario data and store it as dictionary | |
def generate_scenario_data():
    """Build original and modified scenario data for every metric and panel.

    For each metric from ``get_metrics_names()`` and each panel from
    ``get_panels_names()``, ``initialize_data`` supplies a scenario that is
    converted with ``class_convert_to_dict``. The modified structure reuses any
    previously saved entry from
    ``project_dct["scenario_planner"]["modified_metadata_file"]`` for that
    metric/panel; otherwise it gets a fresh conversion of the same scenario.

    Both results are written back to
    ``st.session_state["project_dct"]["scenario_planner"]`` under
    ``"original_metadata_file"`` and ``"modified_metadata_file"``.
    """
    # Previously saved edits. Default to an empty dict so lookups below fall
    # back cleanly instead of relying on an unbound name being swallowed.
    modified_data = st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ]
    if modified_data is None:
        modified_data = {}

    all_scenario_data_original = {}
    all_scenario_data_modified = {}

    for metric in get_metrics_names():
        original_by_panel = all_scenario_data_original.setdefault(metric, {})
        modified_by_panel = all_scenario_data_modified.setdefault(metric, {})

        for panel in get_panels_names(metric):
            _, scenario = initialize_data(
                panel=panel,
                metrics=metric,
            )
            original_by_panel[panel] = class_convert_to_dict(scenario)

            try:
                modified_by_panel[panel] = modified_data[metric][panel]
            except (KeyError, TypeError):
                # Converted a second time rather than reusing the original
                # entry, presumably so the two structures are independent objects.
                modified_by_panel[panel] = class_convert_to_dict(scenario)

    # Write the original scenario data
    st.session_state["project_dct"]["scenario_planner"][
        "original_metadata_file"
    ] = all_scenario_data_original
    # Write the modified scenario data
    st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ] = all_scenario_data_modified
| ############################################################################################################# | |