wolfofbackstreet committed on
Commit
7e8e3a6
·
verified ·
1 Parent(s): 3c9009e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -91
app.py CHANGED
@@ -1,92 +1,91 @@
1
- import os
2
- from flask import Flask, request, jsonify, send_file
3
- from optimum.intel.openvino.modeling_diffusion import OVStableDiffusionPipeline
4
- from PIL import Image
5
- import io
6
- import torch
7
- import logging
8
-
9
- # Set up logging
10
- logging.basicConfig(level=logging.INFO)
11
- logger = logging.getLogger(__name__)
12
-
13
- app = Flask(__name__)
14
-
15
- # Set cache directories
16
- os.environ["HF_HOME"] = "/app/cache/huggingface"
17
- os.environ["MPLCONFIGDIR"] = "/app/matplotlib_cache"
18
- os.environ["OPENVINO_TELEMETRY_DIR"] = "/app/openvino_cache"
19
-
20
- # Ensure cache directories exist
21
- for cache_dir in ["/app/cache/huggingface", "/app/matplotlib_cache", "/app/openvino_cache"]:
22
- os.makedirs(cache_dir, exist_ok=True)
23
-
24
- # Load the base pre-converted OpenVINO SDXL model
25
- base_model_id = "rupeshs/hyper-sd-sdxl-1-step-openvino-int8"
26
- try:
27
- pipeline = OVStableDiffusionPipeline.from_pretrained(
28
- base_model_id,
29
- ov_config={"CACHE_DIR": "/app/cache/openvino"},
30
- device="CPU"
31
- )
32
- pipeline.enable_tiny_auto_encoder()
33
- logger.info("Base model loaded successfully")
34
- except Exception as e:
35
- logger.error(f"Failed to load base model: {str(e)}")
36
- raise
37
-
38
- @app.route('/generate', methods=['POST'])
39
- def generate_image():
40
- try:
41
- # Get parameters from request
42
- data = request.get_json()
43
- prompt = data.get('prompt', 'A futuristic cityscape at sunset, cyberpunk style, 8k')
44
- width = data.get('width', 512)
45
- height = data.get('height', 512)
46
- num_inference_steps = data.get('num_inference_steps', 4)
47
- guidance_scale = data.get('guidance_scale', 1.0)
48
- # lora_model_id = data.get('lora_model_id', None)
49
- # lora_weight = data.get('lora_weight', 0.8)
50
-
51
- # Load LoRA weights if specified
52
- local_pipeline = pipeline
53
- # if lora_model_id:
54
- # try:
55
- # local_pipeline = LoraLoaderMixin.load_lora_weights(
56
- # local_pipeline,
57
- # lora_model_id,
58
- # lora_scale=lora_weight,
59
- # cache_dir="/app/cache/huggingface"
60
- # )
61
- # logger.info(f"LoRA model {lora_model_id} loaded successfully")
62
- # except Exception as e:
63
- # logger.error(f"Failed to load LoRA model: {str(e)}")
64
- # return jsonify({'error': f"Failed to load LoRA model: {str(e)}"}), 400
65
-
66
- # Generate image
67
- image = local_pipeline(
68
- prompt=prompt,
69
- width=width,
70
- height=height,
71
- num_inference_steps=num_inference_steps,
72
- guidance_scale=guidance_scale
73
- ).images[0]
74
-
75
- # Save image to a bytes buffer
76
- img_io = io.BytesIO()
77
- image.save(img_io, 'PNG')
78
- img_io.seek(0)
79
-
80
- return send_file(
81
- img_io,
82
- mimetype='image/png',
83
- as_attachment=True,
84
- download_name='generated_image.png'
85
- )
86
- except Exception as e:
87
- logger.error(f"Image generation failed: {str(e)}")
88
- return jsonify({'error': str(e)}), 500
89
-
90
- if __name__ == '__main__':
91
- port = int(os.getenv('PORT', 7860))
92
  app.run(host='0.0.0.0', port=port)
 
1
+ import os
2
+ from flask import Flask, request, jsonify, send_file
3
+ from optimum.intel.openvino.modeling_diffusion import OVStableDiffusionPipeline
4
+ from PIL import Image
5
+ import io
6
+ import torch
7
+ import logging
8
+
9
+ # Set up logging
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
+
13
+ app = Flask(__name__)
14
+
15
+ # Set cache directories
16
+ os.environ["HF_HOME"] = "/app/cache/huggingface"
17
+ os.environ["MPLCONFIGDIR"] = "/app/matplotlib_cache"
18
+ os.environ["OPENVINO_TELEMETRY_DIR"] = "/app/openvino_cache"
19
+
20
+ # Ensure cache directories exist
21
+ for cache_dir in ["/app/cache/huggingface", "/app/matplotlib_cache", "/app/openvino_cache"]:
22
+ os.makedirs(cache_dir, exist_ok=True)
23
+
24
+ # Load the base pre-converted OpenVINO SDXL model
25
+ base_model_id = "rupeshs/hyper-sd-sdxl-1-step-openvino-int8"
26
+ try:
27
+ pipeline = OVStableDiffusionPipeline.from_pretrained(
28
+ base_model_id,
29
+ ov_config={"CACHE_DIR": "/app/cache/openvino"},
30
+ device="CPU"
31
+ )
32
+ logger.info("Base model loaded successfully")
33
+ except Exception as e:
34
+ logger.error(f"Failed to load base model: {str(e)}")
35
+ raise
36
+
@app.route('/generate', methods=['POST'])
def generate_image():
    """Generate a PNG image from a text prompt.

    Expects an optional JSON body with keys:
        prompt (str), width (int), height (int),
        num_inference_steps (int), guidance_scale (float).
    All keys are optional and fall back to defaults.

    Returns the generated image as a PNG attachment, a JSON error with
    status 400 for invalid parameters, or 500 if generation fails.
    """
    # silent=True makes get_json() return None on a missing/malformed body
    # instead of raising; fall back to an empty dict so defaults apply.
    data = request.get_json(silent=True) or {}

    try:
        prompt = data.get('prompt', 'A futuristic cityscape at sunset, cyberpunk style, 8k')
        # Coerce numeric parameters up front so bad client input is a 400,
        # not a 500 from deep inside the pipeline.
        width = int(data.get('width', 512))
        height = int(data.get('height', 512))
        num_inference_steps = int(data.get('num_inference_steps', 4))
        guidance_scale = float(data.get('guidance_scale', 1.0))
    except (TypeError, ValueError) as e:
        return jsonify({'error': f'Invalid parameter: {str(e)}'}), 400

    try:
        # Run the diffusion pipeline; .images[0] is a PIL image.
        image = pipeline(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale
        ).images[0]

        # Stream the PNG back from memory without touching disk.
        img_io = io.BytesIO()
        image.save(img_io, 'PNG')
        img_io.seek(0)

        return send_file(
            img_io,
            mimetype='image/png',
            as_attachment=True,
            download_name='generated_image.png'
        )
    except Exception as e:
        # Boundary handler: log the failure and surface it as JSON.
        logger.error(f"Image generation failed: {str(e)}")
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # Serve on all interfaces; the listening port is configurable via the
    # PORT environment variable (defaults to 7860, the HF Spaces convention).
    server_port = int(os.environ.get('PORT', '7860'))
    app.run(host='0.0.0.0', port=server_port)