@@ -1026,7 +1026,9 @@ def get_prediction(self, x_pos, x_neg, return_normalize_scores=False):
1026
1026
return prediction
1027
1027
1028
1028
def fit_one (self , pos : Union [List [OWLNamedIndividual ], List [str ]], neg : Union [List [OWLNamedIndividual ], List [str ]]):
1029
-
1029
+ def simple_strategy (strategy : SimpleSolution , prediction : List [str ]):
1030
+ return self .dl_parser .parse (strategy .predict (prediction ))
1031
+
1030
1032
if isinstance (pos [0 ], OWLNamedIndividual ):
1031
1033
pos_str = [ind .str .split ("/" )[- 1 ] for ind in pos ]
1032
1034
neg_str = [ind .str .split ("/" )[- 1 ] for ind in neg ]
@@ -1049,41 +1051,33 @@ def fit_one(self, pos: Union[List[OWLNamedIndividual], List[str]], neg: Union[Li
1049
1051
x_pos , x_neg = next (iter (dataloader ))
1050
1052
simpleSolution = SimpleSolution (list (self .vocab ), self .atomic_concept_names )
1051
1053
predictions_raw = self .get_prediction (x_pos , x_neg )
1052
- counter = 0
1053
-
1054
+
1055
+ if self .enforce_validity :
1056
+ concept_ast_builder = ConceptAbstractSyntaxTreeBuilder (knowledge_base = self .knowledge_base )
1057
+
1054
1058
predictions = []
1055
1059
for prediction in predictions_raw :
1060
+ prediction_str = "" .join (before_pad (prediction .squeeze ()))
1056
1061
try :
1057
- prediction_str = "" .join (before_pad (prediction .squeeze ()))
1058
1062
concept = self .dl_parser .parse (prediction_str )
1059
- counter += 1
1060
1063
except :
1061
- prediction_str = simpleSolution .predict ("" .join (before_pad (prediction .squeeze ())))
1062
- concept = self .dl_parser .parse (prediction_str )
1063
- if self .verbose > 0 :
1064
- print ("Prediction: " , prediction_str )
1064
+ if self .enforce_validity :
1065
+ try :
1066
+ raw_prediction = [pred for pred in prediction if pred != 'PAD' ]
1067
+ parse_concept_str , _ = concept_ast_builder .parse (token_sequence = raw_prediction , enforce_validity = True )
1068
+
1069
+ try :
1070
+ concept = self .dl_parser .parse (parse_concept_str )
1071
+ except :
1072
+ prediction_str = simpleSolution .predict (prediction_str )
1073
+ concept = self .dl_parser .parse (prediction_str )
1074
+ except :
1075
+ concept = simple_strategy (simpleSolution , prediction_str )
1076
+ else :
1077
+ concept = simple_strategy (simpleSolution , prediction_str )
1078
+ if self .verbose > 0 :
1079
+ print ("Prediction: " , prediction_str )
1065
1080
predictions .append (concept )
1066
- # else:
1067
- # concept_abstract_syntax_tree_builder = ConceptAbstractSyntaxTreeBuilder(knowledge_base=self.knowledge_base)
1068
- # # prediction_scores = self.get_prediction(x_pos, x_neg, return_normalize_scores=True)
1069
-
1070
- # # decoder = DecodingStrategy(prediction_scores, self.inv_vocab, self.num_predictions, self.max_length)
1071
- # # decoded_raw_predictions = decoder.decode(strategy_type=self.enforce_validity) #TODO Handle kwargs
1072
- # unpad_raw_predictions = [[pred for pred in pred_seq if pred != 'PAD'] for pred_seq in predictions_raw]
1073
-
1074
- # for prediction in unpad_raw_predictions:
1075
- # try:
1076
- # parse_concept_str, _ = concept_abstract_syntax_tree_builder.parse(token_sequence=prediction, enforce_validity=True)
1077
-
1078
- # try:
1079
- # concept = self.dl_parser.parse(parse_concept_str)
1080
- # counter += 1
1081
- # except:
1082
- # pass
1083
- # predictions.append(concept)
1084
- # except Exception as e:
1085
- # pass
1086
- # print(f"{counter/self.num_predictions:.2f}") #average valid
1087
1081
return predictions
1088
1082
1089
1083
def fit (self , learning_problem : PosNegLPStandard , ** kwargs ):
0 commit comments