2525class explain ():
2626 def __init__ (self ):
2727 super (explain , self ).__init__ ()
28- self .param = None
28+ self .param = {}
2929
3030 # is classification function?
3131
@@ -47,8 +47,6 @@ def ai(self, df, y, model, model_name="xgboost", mode=None):
4747 y_variable_predict = "y_prediction"
4848
4949
50- # is classification?
51- is_classification = self .is_classification_given_y_array (y )
5250
5351 # If yes, then different shap functions are required.
5452 # get the shap value based on prediction and make a new dataframe.
@@ -66,10 +64,14 @@ def ai(self, df, y, model, model_name="xgboost", mode=None):
6664 else :
6765 prediction_col = model .predict (df .to_numpy ())
6866
67+ # is classification?
68+ is_classification = self .is_classification_given_y_array (prediction_col )
69+
70+
6971
7072 #shap
7173 c = calculate_shap ()
72- self .df_final = c .find (model , df , prediction_col , is_classification , model_name = model_name )
74+ self .df_final , self . explainer = c .find (model , df , prediction_col , is_classification , model_name = model_name )
7375
7476 #prediction col
7577 self .df_final [y_variable_predict ] = prediction_col
@@ -78,8 +80,43 @@ def ai(self, df, y, model, model_name="xgboost", mode=None):
7880
7981 self .df_final [y_variable ] = y
8082
83+
84+ #additional inputs.
85+ if is_classification == True :
86+ # find and add probabilities in the dataset.
87+ prediction_col_prob = model .predict_proba (df .to_numpy ())
88+ pd_prediction_col_prob = pd .DataFrame (prediction_col_prob )
89+
90+ for c in pd_prediction_col_prob .columns :
91+ self .df_final ["probability_of_predicting_class_" + str (c )] = list (pd_prediction_col_prob [c ])
92+
93+ classes = []
94+ for c in pd_prediction_col_prob .columns :
95+ classes .append (str (c ))
96+ self .param ["classes" ]= classes
97+
98+ try :
99+ expected_values_by_class = self .explainer .expected_value
100+ except :
101+ expected_values_by_class = []
102+ for c in range (len (classes )):
103+ expected_values_by_class .append (1 / len (classes ))
104+
105+
106+ self .param ["expected_values" ]= expected_values_by_class
107+ else :
108+ try :
109+ expected_values = self .explainer .expected_value
110+ self .param ["expected_values" ] = [expected_values ]
111+ except :
112+ expected_value = [round (np .array (y ).mean (),2 )]
113+ self .param ["expected_values" ] = expected_value
114+
115+
116+ self .param ["is_classification" ]= is_classification
117+
81118 d = dashboard ()
82- d .find (self .df_final , y_variable , y_variable_predict , mode )
119+ d .find (self .df_final , y_variable , y_variable_predict , mode , self . param )
83120
84121 return True
85122
0 commit comments