88hours committed
Commit 60e35a0 · 1 Parent(s): 6255502

Must Refactor Code :)

Files changed (1)
  1. s5-how-to-umap.py +48 -35
s5-how-to-umap.py CHANGED
@@ -1,3 +1,4 @@
+from os import path
 from IPython.display import display
 from umap import UMAP
 from sklearn.preprocessing import MinMaxScaler
@@ -9,7 +10,8 @@ from s2_download_data import load_data_from_huggingface
 from utils import prepare_dataset_for_umap_visualization as data_prep
 from s3_data_to_vector_embedding import bt_embeddings_from_local
 import random
-
+import numpy as np
+import torch
 # prompt templates
 templates = [
     'a picture of {}',
@@ -35,54 +37,63 @@ def data_prep(hf_dataset_name, templates=templates, test_size=1000):
         })
     return img_txt_pairs
 
-# prepare image_text pairs
+# compute BridgeTower embeddings for cat image-text pairs
+def load_cat_and_car_embeddings():
+
+    # prepare image_text pairs
 
-# for the first 50 data of Huggingface dataset
-# "yashikota/cat-image-dataset"
-cat_img_txt_pairs = data_prep("yashikota/cat-image-dataset",
-                              "cat", test_size=50)
+    # for the first 50 data of Huggingface dataset
+    # "yashikota/cat-image-dataset"
+    cat_img_txt_pairs = data_prep("yashikota/cat-image-dataset",
+                                  "cat", test_size=50)
 
-# for the first 50 data of Huggingface dataset
-# "tanganke/stanford_cars"
-car_img_txt_pairs = data_prep("tanganke/stanford_cars",
-                              "car", test_size=50)
+    # for the first 50 data of Huggingface dataset
+    # "tanganke/stanford_cars"
+    car_img_txt_pairs = data_prep("tanganke/stanford_cars",
+                                  "car", test_size=50)
 
-# display an example of a cat image-text pair data
-display(cat_img_txt_pairs[0]['caption'])
-display(cat_img_txt_pairs[0]['pil_img'])
+    # display an example of a cat image-text pair data
+    display(cat_img_txt_pairs[0]['caption'])
+    display(cat_img_txt_pairs[0]['pil_img'])
 
-# display an example of a car image-text pair data
-display(car_img_txt_pairs[0]['caption'])
-display(car_img_txt_pairs[0]['pil_img'])
+    # display an example of a car image-text pair data
+    display(car_img_txt_pairs[0]['caption'])
+    display(car_img_txt_pairs[0]['pil_img'])
+
+    def save_embeddings(embedding, path):
+        torch.save(embedding, path)
 
-# compute BridgeTower embeddings for cat image-text pairs
-def load_cat_and_car_embeddings():
-
     def load_embeddings(img_txt_pair):
         pil_img = img_txt_pair['pil_img']
         caption = img_txt_pair['caption']
         return bt_embeddings_from_local(caption, pil_img)
 
-    cat_embeddings = []
-    for img_txt_pair in tqdm(
+    def load_all_embeddings_from_image_text_pairs(file_name):
+        cat_embeddings = []
+        for img_txt_pair in tqdm(
            cat_img_txt_pairs,
            total=len(cat_img_txt_pairs)
        ):
-        pil_img = img_txt_pair['pil_img']
-        caption = img_txt_pair['caption']
-        embedding =load_embeddings(caption, pil_img)
-        cat_embeddings.append(embedding)
+            pil_img = img_txt_pair['pil_img']
+            caption = img_txt_pair['caption']
+            embedding = load_embeddings(caption, pil_img)
+            cat_embeddings.append(embedding)
+        save_embeddings(cat_embeddings, file_name)
+        return cat_embeddings
+
 
-    # compute BridgeTower embeddings for car image-text pairs
+    cat_embeddings = []
     car_embeddings = []
-    for img_txt_pair in tqdm(
-        car_img_txt_pairs,
-        total=len(car_img_txt_pairs)
-    ):
-        pil_img = img_txt_pair['pil_img']
-        caption = img_txt_pair['caption']
-        embedding = load_embeddings(caption, pil_img)
-        car_embeddings.append(embedding)
+    if (path.exists('./shared_data/cat_embeddings.pt')):
+        cat_embeddings = torch.load('./shared_data/cat_embeddings.pt')
+    else:
+        cat_embeddings = load_all_embeddings_from_image_text_pairs('./shared_data/cat_embeddings.pt')
+
+    if (path.exists('./shared_data/car_embeddings.pt')):
+        car_embeddings = torch.load('./shared_data/car_embeddings.pt')
+    else:
+        car_embeddings = load_all_embeddings_from_image_text_pairs('./shared_data/car_embeddings.pt')
+
     return cat_embeddings, car_embeddings
 
 
@@ -123,4 +134,6 @@ def show_umap_visualization():
     plt.title('Scatter plot of images of cats and cars using UMAP')
     plt.xlabel('X')
     plt.ylabel('Y')
-    plt.show()
+    plt.show()
+
+load_cat_and_car_embeddings()
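Note that, as committed, `load_all_embeddings_from_image_text_pairs` still iterates the enclosing `cat_img_txt_pairs` even when it is asked to build the car embeddings, and it calls `load_embeddings(caption, pil_img)` although that helper takes a single `img_txt_pair`. A minimal corrected sketch follows; the `load_or_compute` helper and the parameter names are illustrative, not part of the commit:

```python
from os import path
import torch
from tqdm import tqdm

def load_all_embeddings_from_image_text_pairs(img_txt_pairs, file_name):
    # compute one BridgeTower embedding per image-text pair
    embeddings = []
    for img_txt_pair in tqdm(img_txt_pairs, total=len(img_txt_pairs)):
        embeddings.append(bt_embeddings_from_local(
            img_txt_pair['caption'], img_txt_pair['pil_img']))
    torch.save(embeddings, file_name)  # cache for later runs
    return embeddings

def load_or_compute(img_txt_pairs, file_name):
    # hypothetical helper: reuse the cached .pt file when it exists
    if path.exists(file_name):
        return torch.load(file_name)
    return load_all_embeddings_from_image_text_pairs(img_txt_pairs, file_name)

# usage inside load_cat_and_car_embeddings():
# cat_embeddings = load_or_compute(cat_img_txt_pairs, './shared_data/cat_embeddings.pt')
# car_embeddings = load_or_compute(car_img_txt_pairs, './shared_data/car_embeddings.pt')
```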
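The diff leaves the body of `show_umap_visualization()` unchanged and therefore unshown, apart from its final plotting lines. For readers following along, here is a minimal sketch of the projection those plot labels imply, assuming each stored embedding converts to a flat numeric vector; `umap_scatter` and its parameters are illustrative, not the file's actual code:

```python
import numpy as np
import matplotlib.pyplot as plt
from umap import UMAP
from sklearn.preprocessing import MinMaxScaler

def umap_scatter(cat_embeddings, car_embeddings):
    # stack the per-pair embeddings into one (n_samples, dim) matrix
    X = np.array([np.asarray(e).ravel()
                  for e in list(cat_embeddings) + list(car_embeddings)])
    X = MinMaxScaler().fit_transform(X)  # scale features to [0, 1]
    proj = UMAP(n_components=2, random_state=0).fit_transform(X)  # 2-D projection
    n_cat = len(cat_embeddings)
    plt.scatter(proj[:n_cat, 0], proj[:n_cat, 1], label='cat')
    plt.scatter(proj[n_cat:, 0], proj[n_cat:, 1], label='car')
    plt.title('Scatter plot of images of cats and cars using UMAP')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.legend()
    plt.show()
```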