Skip to content

Commit 7a38086

Browse files
committed
feat: add modified validity
1 parent 197b6b5 commit 7a38086

File tree

1 file changed

+24
-30
lines changed

1 file changed

+24
-30
lines changed

ontolearn/concept_learner.py

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,7 +1026,9 @@ def get_prediction(self, x_pos, x_neg, return_normalize_scores=False):
10261026
return prediction
10271027

10281028
def fit_one(self, pos: Union[List[OWLNamedIndividual], List[str]], neg: Union[List[OWLNamedIndividual], List[str]]):
1029-
1029+
def simple_strategy(strategy: SimpleSolution, prediction: List[str]):
1030+
return self.dl_parser.parse(strategy.predict(prediction))
1031+
10301032
if isinstance(pos[0], OWLNamedIndividual):
10311033
pos_str = [ind.str.split("/")[-1] for ind in pos]
10321034
neg_str = [ind.str.split("/")[-1] for ind in neg]
@@ -1049,41 +1051,33 @@ def fit_one(self, pos: Union[List[OWLNamedIndividual], List[str]], neg: Union[Li
10491051
x_pos, x_neg = next(iter(dataloader))
10501052
simpleSolution = SimpleSolution(list(self.vocab), self.atomic_concept_names)
10511053
predictions_raw = self.get_prediction(x_pos, x_neg)
1052-
counter = 0
1053-
1054+
1055+
if self.enforce_validity:
1056+
concept_ast_builder = ConceptAbstractSyntaxTreeBuilder(knowledge_base=self.knowledge_base)
1057+
10541058
predictions = []
10551059
for prediction in predictions_raw:
1060+
prediction_str = "".join(before_pad(prediction.squeeze()))
10561061
try:
1057-
prediction_str = "".join(before_pad(prediction.squeeze()))
10581062
concept = self.dl_parser.parse(prediction_str)
1059-
counter +=1
10601063
except:
1061-
prediction_str = simpleSolution.predict("".join(before_pad(prediction.squeeze())))
1062-
concept = self.dl_parser.parse(prediction_str)
1063-
if self.verbose>0:
1064-
print("Prediction: ", prediction_str)
1064+
if self.enforce_validity:
1065+
try:
1066+
raw_prediction = [pred for pred in prediction if pred != 'PAD']
1067+
parse_concept_str, _ = concept_ast_builder.parse(token_sequence=raw_prediction, enforce_validity=True)
1068+
1069+
try:
1070+
concept = self.dl_parser.parse(parse_concept_str)
1071+
except:
1072+
prediction_str = simpleSolution.predict(prediction_str)
1073+
concept = self.dl_parser.parse(prediction_str)
1074+
except:
1075+
concept = simple_strategy(simpleSolution, prediction_str)
1076+
else:
1077+
concept = simple_strategy(simpleSolution, prediction_str)
1078+
if self.verbose>0:
1079+
print("Prediction: ", prediction_str)
10651080
predictions.append(concept)
1066-
# else:
1067-
# concept_abstract_syntax_tree_builder = ConceptAbstractSyntaxTreeBuilder(knowledge_base=self.knowledge_base)
1068-
# # prediction_scores = self.get_prediction(x_pos, x_neg, return_normalize_scores=True)
1069-
1070-
# # decoder = DecodingStrategy(prediction_scores, self.inv_vocab, self.num_predictions, self.max_length)
1071-
# # decoded_raw_predictions = decoder.decode(strategy_type=self.enforce_validity) #TODO Handle kwargs
1072-
# unpad_raw_predictions = [[pred for pred in pred_seq if pred != 'PAD'] for pred_seq in predictions_raw]
1073-
1074-
# for prediction in unpad_raw_predictions:
1075-
# try:
1076-
# parse_concept_str, _ = concept_abstract_syntax_tree_builder.parse(token_sequence=prediction, enforce_validity=True)
1077-
1078-
# try:
1079-
# concept = self.dl_parser.parse(parse_concept_str)
1080-
# counter += 1
1081-
# except:
1082-
# pass
1083-
# predictions.append(concept)
1084-
# except Exception as e:
1085-
# pass
1086-
# print(f"{counter/self.num_predictions:.2f}") #average valid
10871081
return predictions
10881082

10891083
def fit(self, learning_problem: PosNegLPStandard, **kwargs):

0 commit comments

Comments
 (0)