Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions perturb_EB.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import jsonlines
from tqdm import tqdm
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
import torch
import numpy as np

gold_path = 'entailment_bank/data/public_dataset/entailment_trees_emnlp2021_data_v2/dataset/task_1/test.jsonl'
more_path = 'entailment_bank/data/public_dataset/entailment_trees_emnlp2021_data_v2/dataset/task_2/test.jsonl'
Expand Down Expand Up @@ -73,7 +75,7 @@ def reconstruct_proof(steps, sentences):
proof += '; '
return proof

def repeat_steps(in_steps, in_sentences):
def repeat_steps(in_steps, in_sentences, *args):
steps = deepcopy(in_steps)
sentences = deepcopy(in_sentences)
int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']]
Expand All @@ -87,7 +89,7 @@ def repeat_steps(in_steps, in_sentences):
steps.insert(idx + 1, {'parents': [key], 'child': repeated_node})
return steps, sentences

def delete_steps(in_steps, in_sentences):
def delete_steps(in_steps, in_sentences, *args):
steps = deepcopy(in_steps)
sentences = deepcopy(in_sentences)
int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']]
Expand All @@ -102,7 +104,7 @@ def delete_steps(in_steps, in_sentences):
step['parents'] = [p for p in step['parents'] if p != del_node]
return steps, sentences

def swapped_steps(in_steps, in_sentences):
def swapped_steps(in_steps, in_sentences, *args):
steps = deepcopy(in_steps)
sentences = deepcopy(in_sentences)
int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']]
Expand All @@ -121,7 +123,7 @@ def swapped_steps(in_steps, in_sentences):
step['parents'] = [p for p in step['parents'] if p != swap_node]
return steps, sentences

def negate_step(in_steps, in_sentences):
def negate_step(in_steps, in_sentences, *args):
steps = deepcopy(in_steps)
sentences = deepcopy(in_sentences)
int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']]
Expand All @@ -145,7 +147,7 @@ def hallucinate_step(in_steps, in_sentences, extra_sentences):
print(sentences[hallucinate_node])
return steps, sentences

def paraphrase_steps(in_steps, in_sentences):
def paraphrase_steps(in_steps, in_sentences, *args):
steps = deepcopy(in_steps)
sentences = deepcopy(in_sentences)
int_idxs = [s for s in range(len(steps)) if 'int' in steps[s]['child']]
Expand Down Expand Up @@ -209,7 +211,8 @@ def redundant_steps(in_steps, in_sentences, extra_sentences):
perturbed = False
custom_unperturbed_ids[perturb_type + "_test.jsonl"].append(id)
else: perturbed = True
tree_entry.append({'perturbed': perturbed, 'perturbations': perturb_type, 'steps':{'original': steps, 'perturbed': perturbed_steps}, 'sentences':{'original': sentences, 'perturbed': perturbed_sentences}, 'written':{'original': original_written_steps, 'perturbed': written_steps}, 'question': question, 'answer': answer})
# tree_entry.append({'perturbed': perturbed, 'perturbations': perturb_type, 'steps':{'original': steps, 'perturbed': perturbed_steps}, 'sentences':{'original': sentences, 'perturbed': perturbed_sentences}, 'written':{'original': original_written_steps, 'perturbed': written_steps}, 'question': question, 'answer': answer})
tree_entry.append({'perturbed': perturbed, 'perturbations': perturb_type, 'steps':{'original': steps, 'perturbed': perturbed_steps}, 'sentences':{'original': sentences, 'perturbed': perturbed_sentences}, 'question': question, 'answer': answer})
with jsonlines.open(os.path.join(tree_dest_path, fname), 'w') as writer:
writer.write_all(tree_entry)

Expand Down