Spaces:
Runtime error
Runtime error
changes0
Browse files- DisentanglementBase.py +9 -9
- test_disentanglement.sh +1 -1
DisentanglementBase.py
CHANGED
|
@@ -175,8 +175,8 @@ class DisentanglementBase:
|
|
| 175 |
return X
|
| 176 |
|
| 177 |
def get_train_val(self, extremes=False):
|
| 178 |
-
X = self.get_encoded_latent()
|
| 179 |
y = np.array(self.df[self.variable].values)
|
|
|
|
| 180 |
if self.categorical:
|
| 181 |
bins = [(x-1) * 360 / (len(self.colors_list) - 1) if x != 1
|
| 182 |
else 1 for x in range(len(self.colors_list) + 1)]
|
|
@@ -443,14 +443,14 @@ class DisentanglementBase:
|
|
| 443 |
axs[i].imshow(image)
|
| 444 |
axs[i].set_title(np.round(lambd, 2))
|
| 445 |
plt.tight_layout()
|
| 446 |
-
plt.savefig(join(self.repo_folder, 'figures', '
|
| 447 |
plt.close()
|
| 448 |
|
| 449 |
if save_separately:
|
| 450 |
for i, (image, lambd) in enumerate(zip(images, lambdas)):
|
| 451 |
plt.imshow(image)
|
| 452 |
plt.tight_layout()
|
| 453 |
-
plt.savefig(join(self.repo_folder, 'figures', '
|
| 454 |
plt.close()
|
| 455 |
|
| 456 |
return images, lambdas
|
|
@@ -556,11 +556,11 @@ def continous_experiment(name, var, repo_folder, model, annotations, df, space,
|
|
| 556 |
|
| 557 |
def main():
|
| 558 |
repo_folder = '.'
|
| 559 |
-
annotations_file = join(repo_folder, 'data/textile_annotated_files/seeds0000-
|
| 560 |
with open(annotations_file, 'rb') as f:
|
| 561 |
annotations = pickle.load(f)
|
| 562 |
|
| 563 |
-
df_file = join(repo_folder, 'data/textile_annotated_files/
|
| 564 |
df = pd.read_csv(df_file).fillna('#000000')
|
| 565 |
|
| 566 |
model_file = join(repo_folder, 'data/textile_model_files/network-snapshot-005000.pkl')
|
|
@@ -571,7 +571,7 @@ def main():
|
|
| 571 |
'Kelly Green', 'Green Blue Seafoam', 'Cyan Blue',
|
| 572 |
'Warm Blue', 'Indigo', 'Purple Magenta', 'Magenta Pink']
|
| 573 |
colors_list = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue',
|
| 574 |
-
'Blue', '
|
| 575 |
|
| 576 |
scores = []
|
| 577 |
kwargs = {'CL method':['LR', 'SVM'], 'C':[0.1, 1], 'sign':[True, False],
|
|
@@ -584,12 +584,12 @@ def main():
|
|
| 584 |
if specific_examples is not None:
|
| 585 |
disentanglemnet_exp = DisentanglementBase(repo_folder, model, annotations, df, space='w', colors_list=colors_list, compute_s=False)
|
| 586 |
|
| 587 |
-
separation_vectors = disentanglemnet_exp.StyleSpace_separation_vector(sign=True, num_factors=10, cutout=None)
|
| 588 |
-
|
| 589 |
for specific_example in specific_examples:
|
| 590 |
seed = specific_example
|
| 591 |
for i, color in enumerate(colors_list):
|
| 592 |
-
disentanglemnet_exp.generate_changes(seed, separation_vectors[i], min_epsilon=-
|
| 593 |
|
| 594 |
return
|
| 595 |
|
|
|
|
| 175 |
return X
|
| 176 |
|
| 177 |
def get_train_val(self, extremes=False):
|
|
|
|
| 178 |
y = np.array(self.df[self.variable].values)
|
| 179 |
+
X = self.get_encoded_latent()[:y.shape[0], :]
|
| 180 |
if self.categorical:
|
| 181 |
bins = [(x-1) * 360 / (len(self.colors_list) - 1) if x != 1
|
| 182 |
else 1 for x in range(len(self.colors_list) + 1)]
|
|
|
|
| 443 |
axs[i].imshow(image)
|
| 444 |
axs[i].set_title(np.round(lambd, 2))
|
| 445 |
plt.tight_layout()
|
| 446 |
+
plt.savefig(join(self.repo_folder, 'figures', 'examples_new', name+'.jpg'))
|
| 447 |
plt.close()
|
| 448 |
|
| 449 |
if save_separately:
|
| 450 |
for i, (image, lambd) in enumerate(zip(images, lambdas)):
|
| 451 |
plt.imshow(image)
|
| 452 |
plt.tight_layout()
|
| 453 |
+
plt.savefig(join(self.repo_folder, 'figures', 'examples_new', name + '_' + str(lambd) + '.jpg'))
|
| 454 |
plt.close()
|
| 455 |
|
| 456 |
return images, lambdas
|
|
|
|
| 556 |
|
| 557 |
def main():
|
| 558 |
repo_folder = '.'
|
| 559 |
+
annotations_file = join(repo_folder, 'data/textile_annotated_files/seeds0000-1000000.pkl')
|
| 560 |
with open(annotations_file, 'rb') as f:
|
| 561 |
annotations = pickle.load(f)
|
| 562 |
|
| 563 |
+
df_file = join(repo_folder, 'data/textile_annotated_files/top_three_colours_00000-730003.csv')
|
| 564 |
df = pd.read_csv(df_file).fillna('#000000')
|
| 565 |
|
| 566 |
model_file = join(repo_folder, 'data/textile_model_files/network-snapshot-005000.pkl')
|
|
|
|
| 571 |
'Kelly Green', 'Green Blue Seafoam', 'Cyan Blue',
|
| 572 |
'Warm Blue', 'Indigo', 'Purple Magenta', 'Magenta Pink']
|
| 573 |
colors_list = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue',
|
| 574 |
+
'Blue', 'Violet', 'Pink']
|
| 575 |
|
| 576 |
scores = []
|
| 577 |
kwargs = {'CL method':['LR', 'SVM'], 'C':[0.1, 1], 'sign':[True, False],
|
|
|
|
| 584 |
if specific_examples is not None:
|
| 585 |
disentanglemnet_exp = DisentanglementBase(repo_folder, model, annotations, df, space='w', colors_list=colors_list, compute_s=False)
|
| 586 |
|
| 587 |
+
# separation_vectors = disentanglemnet_exp.StyleSpace_separation_vector(sign=True, num_factors=10, cutout=None)
|
| 588 |
+
separation_vectors = disentanglemnet_exp.InterFaceGAN_separation_vector(method='LR', C=0.1)
|
| 589 |
for specific_example in specific_examples:
|
| 590 |
seed = specific_example
|
| 591 |
for i, color in enumerate(colors_list):
|
| 592 |
+
disentanglemnet_exp.generate_changes(seed, separation_vectors[i], min_epsilon=-18, max_epsilon=18, savefig=True, save_separately=True, feature=color, method='InterFaceGAN' + '_' + str('LR') + '_' + str(0.1) + '_' + str(None))
|
| 593 |
|
| 594 |
return
|
| 595 |
|
test_disentanglement.sh
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
#!/bin/bash
|
| 2 |
-
#SBATCH --time=
|
| 3 |
#SBATCH --mem=32GB
|
| 4 |
#SBATCH --gres gpu:1
|
| 5 |
|
|
|
|
| 1 |
#!/bin/bash
|
| 2 |
+
#SBATCH --time=02:00:00
|
| 3 |
#SBATCH --mem=32GB
|
| 4 |
#SBATCH --gres gpu:1
|
| 5 |
|