1use super::runtime::{MainCameraRenderTarget, asset_root, configure_shared_app};
2use super::{ViewerCommand, ViewerConfig, ViewerSessionHandle};
3use ak_core::{Structure, Trajectory};
4use bevy::app::{AppExit, ScheduleRunnerPlugin};
5use bevy::camera::RenderTarget;
6use bevy::image::TextureFormatPixelInfo;
7use bevy::prelude::*;
8use bevy::render::render_asset::RenderAssets;
9use bevy::render::render_graph::{self, NodeRunError, RenderGraph, RenderGraphContext};
10use bevy::render::render_resource::{
11 Buffer, BufferDescriptor, BufferUsages, CommandEncoderDescriptor, Extent3d, MapMode, PollType,
12 TexelCopyBufferInfo, TexelCopyBufferLayout, TextureFormat, TextureUsages,
13};
14use bevy::render::renderer::{RenderContext, RenderDevice, RenderQueue};
15use bevy::render::{Extract, ExtractSchedule, Render, RenderApp, RenderSystems};
16use bevy::transform::TransformSystems;
17use bevy::window::ExitCondition;
18use bevy::winit::WinitPlugin;
19use bevy_panorbit_camera::{PanOrbitCameraPlugin, PanOrbitCameraSystemSet};
20use crossbeam_channel::{Receiver, Sender};
21use std::panic::{AssertUnwindSafe, catch_unwind};
22use std::path::{Path, PathBuf};
23use std::sync::{
24 Arc, Mutex,
25 atomic::{AtomicBool, Ordering},
26 mpsc,
27};
28use std::thread;
29use std::time::Duration;
30
/// Warm-up frame count that [`HeadlessRenderConfig::new`] fills in.
const DEFAULT_PREROLL_FRAMES: u32 = 4;
/// Settled-frame count that [`HeadlessRenderConfig::new`] fills in.
const DEFAULT_STABLE_FRAMES: u32 = 2;

/// Describes one off-screen render: where the image goes, how large it is,
/// and how many frames to wait before capturing.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HeadlessRenderConfig {
    /// File the rendered image is written to.
    pub path: PathBuf,
    /// Width of the render target.
    pub width: u32,
    /// Height of the render target.
    pub height: u32,
    /// Settled frames consumed as warm-up before stability counting starts.
    pub preroll_frames: u32,
    /// Consecutive settled frames required before a capture is requested.
    pub stable_frames: u32,
}

impl HeadlessRenderConfig {
    /// Builds a configuration for `path` at `width` x `height`, using the
    /// default preroll and stable-frame counts.
    pub fn new(path: impl Into<PathBuf>, width: u32, height: u32) -> Self {
        let path = path.into();
        Self {
            path,
            width,
            height,
            preroll_frames: DEFAULT_PREROLL_FRAMES,
            stable_frames: DEFAULT_STABLE_FRAMES,
        }
    }
}
54
/// Error type returned by the headless export entry points.
///
/// It carries nothing but a human-readable message; I/O failures are
/// flattened to their display text on conversion.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HeadlessRenderError {
    message: String,
}

impl HeadlessRenderError {
    /// Wraps anything convertible into a `String` as the error message.
    fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl std::fmt::Display for HeadlessRenderError {
    /// Prints the stored message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.message)
    }
}

impl std::error::Error for HeadlessRenderError {}

impl From<std::io::Error> for HeadlessRenderError {
    /// Keeps only the display text of the I/O error.
    fn from(value: std::io::Error) -> Self {
        Self {
            message: value.to_string(),
        }
    }
}
81
/// Immutable capture parameters inserted by `build_app`.
#[derive(Resource, Clone)]
struct CaptureSettings {
    /// Output file handed to `save_image_data` by `save_capture`.
    path: PathBuf,
    /// Threshold that `queue_capture` compares the running stable-frame
    /// counter against; `build_app` clamps it to at least 1.
    stable_frames: u32,
}
87
/// Requested render-target dimensions, read by `setup_headless_target` when
/// it creates the target images.
#[derive(Resource, Clone, Copy)]
struct HeadlessTargetSpec {
    /// Target width; `build_app` rejects zero.
    width: u32,
    /// Target height; `build_app` rejects zero.
    height: u32,
}
93
/// Mutable bookkeeping shared by `queue_capture` and `save_capture`.
#[derive(Resource, Default)]
struct CaptureState {
    /// Warm-up frames still to be consumed; decremented by `queue_capture`
    /// on frames where nothing is pending.
    preroll_remaining: u32,
    /// Consecutive settled frames counted so far; reset to 0 whenever the
    /// viewer or camera reports pending work or the driver is still running.
    stable_frames: u32,
    /// Set by `queue_capture` once enough stable frames were seen; cleared
    /// by `save_capture` after the image was written.
    requested: bool,
    /// Number of `save_capture` runs to skip after the request was made
    /// (`queue_capture` sets it to 1).
    request_delay_remaining: u32,
}
101
102#[derive(Resource, Clone)]
103struct CaptureResult {
104 value: Arc<Mutex<Option<Result<(), HeadlessRenderError>>>>,
105}
106
107impl CaptureResult {
108 fn store(&self, result: Result<(), HeadlessRenderError>) {
109 if let Ok(mut slot) = self.value.lock() {
110 *slot = Some(result);
111 }
112 }
113
114 fn take(&self) -> Option<Result<(), HeadlessRenderError>> {
115 self.value.lock().ok().and_then(|mut slot| slot.take())
116 }
117}
118
/// Main-world end of the frame channel created in `setup_image_copy`.
/// `save_capture` drains it to obtain the raw bytes read back from the GPU.
#[derive(Resource, Deref)]
struct MainWorldReceiver(Receiver<Vec<u8>>);
121
/// Present only when a session driver thread is attached (see `build_app`).
#[derive(Resource, Clone)]
struct SessionDriverState {
    /// Flag raised by the driver thread in `export_image_with_session`;
    /// `queue_capture` holds back the capture while it reads `false`.
    complete: Arc<AtomicBool>,
}
126
/// Render-world end of the frame channel created in `setup_image_copy`.
/// `receive_image_from_buffer` pushes one byte vector per copied buffer.
#[derive(Resource, Deref)]
struct RenderWorldSender(Sender<Vec<u8>>);
129
/// Component holding the handle of the CPU-side image that
/// `save_image_data` fills with the read-back bytes and writes to disk.
#[derive(Component, Deref, DerefMut)]
struct ImageToSave(Handle<Image>);
132
/// Render-world list of every `ImageCopier`; re-inserted on each extract by
/// `image_copy_extract` and read by the copy node and the read-back system.
#[derive(Clone, Default, Resource, Deref, DerefMut)]
struct ImageCopiers(pub Vec<ImageCopier>);
135
/// Pairs a render-target image with the buffer its texels are copied into.
#[derive(Clone, Component)]
struct ImageCopier {
    /// Destination buffer, created with `MAP_READ | COPY_DST` usage in
    /// `ImageCopier::new`.
    buffer: Buffer,
    /// Image whose GPU texture is the copy source.
    src_image: Handle<Image>,
}
141
/// Render-graph label under which `setup_image_copy` registers the
/// `ImageCopyDriver` node.
#[derive(bevy::render::render_graph::RenderLabel, Debug, PartialEq, Eq, Clone, Hash)]
struct ImageCopy;
144
/// Render-graph node that copies each `ImageCopier`'s source texture into
/// its buffer (see the `render_graph::Node` impl).
#[derive(Default)]
struct ImageCopyDriver;
147
148impl ImageCopier {
149 fn new(src_image: Handle<Image>, size: Extent3d, render_device: &RenderDevice) -> Self {
150 let padded_bytes_per_row = RenderDevice::align_copy_bytes_per_row(size.width as usize * 4);
151 let buffer = render_device.create_buffer(&BufferDescriptor {
152 label: None,
153 size: padded_bytes_per_row as u64 * size.height as u64,
154 usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
155 mapped_at_creation: false,
156 });
157 Self { buffer, src_image }
158 }
159}
160
161impl render_graph::Node for ImageCopyDriver {
162 fn run(
163 &self,
164 _graph: &mut RenderGraphContext,
165 render_context: &mut RenderContext,
166 world: &World,
167 ) -> Result<(), NodeRunError> {
168 let image_copiers = world
169 .get_resource::<ImageCopiers>()
170 .expect("image copiers should be extracted");
171 let gpu_images = world
172 .get_resource::<RenderAssets<bevy::render::texture::GpuImage>>()
173 .expect("gpu images should exist");
174
175 for image_copier in image_copiers.iter() {
176 let src_image = gpu_images
177 .get(&image_copier.src_image)
178 .expect("render target image should exist");
179 let mut encoder = render_context
180 .render_device()
181 .create_command_encoder(&CommandEncoderDescriptor::default());
182
183 let block_dimensions = src_image.texture_format.block_dimensions();
184 let block_size = src_image.texture_format.block_copy_size(None).unwrap();
185 let padded_bytes_per_row = RenderDevice::align_copy_bytes_per_row(
186 (src_image.size.width as usize / block_dimensions.0 as usize) * block_size as usize,
187 );
188
189 encoder.copy_texture_to_buffer(
190 src_image.texture.as_image_copy(),
191 TexelCopyBufferInfo {
192 buffer: &image_copier.buffer,
193 layout: TexelCopyBufferLayout {
194 offset: 0,
195 bytes_per_row: Some(padded_bytes_per_row as u32),
196 rows_per_image: None,
197 },
198 },
199 src_image.size,
200 );
201
202 let render_queue = world.get_resource::<RenderQueue>().unwrap();
203 render_queue.submit(std::iter::once(encoder.finish()));
204 }
205
206 Ok(())
207 }
208}
209
210pub fn export_image(
211 trajectory: Trajectory,
212 config: ViewerConfig,
213 export: HeadlessRenderConfig,
214) -> Result<(), HeadlessRenderError> {
215 run_headless(trajectory, config, export, None, None)
216}
217
218pub fn export_structure_image(
219 structure: Structure,
220 config: ViewerConfig,
221 export: HeadlessRenderConfig,
222) -> Result<(), HeadlessRenderError> {
223 export_image(Trajectory::new(vec![structure]), config, export)
224}
225
226pub fn export_image_with_session<F>(
227 trajectory: Trajectory,
228 config: ViewerConfig,
229 export: HeadlessRenderConfig,
230 driver: F,
231) -> Result<(), HeadlessRenderError>
232where
233 F: FnOnce(ViewerSessionHandle) + Send + 'static,
234{
235 let (sender, receiver) = mpsc::channel();
236 let handle = ViewerSessionHandle::new(sender);
237 let driver_complete = Arc::new(AtomicBool::new(false));
238 let driver_complete_thread = driver_complete.clone();
239 let driver_thread = thread::Builder::new()
240 .name("ak-headless-session-driver".to_string())
241 .spawn(move || {
242 driver(handle);
243 driver_complete_thread.store(true, Ordering::Release);
244 })
245 .map_err(|err| HeadlessRenderError::new(format!("failed to spawn driver: {err}")))?;
246
247 let result = run_headless(
248 trajectory,
249 config,
250 export,
251 Some(receiver),
252 Some(driver_complete),
253 );
254 let _ = driver_thread.join();
255 result
256}
257
258pub fn export_prepared_image(
259 trajectory: Trajectory,
260 config: ViewerConfig,
261 export: HeadlessRenderConfig,
262 receiver: mpsc::Receiver<ViewerCommand>,
263) -> Result<(), HeadlessRenderError> {
264 run_headless(trajectory, config, export, Some(receiver), None)
265}
266
267fn run_headless(
268 trajectory: Trajectory,
269 mut config: ViewerConfig,
270 export: HeadlessRenderConfig,
271 receiver: Option<mpsc::Receiver<ViewerCommand>>,
272 session_driver_complete: Option<Arc<AtomicBool>>,
273) -> Result<(), HeadlessRenderError> {
274 config.render.show_ui = false;
275 config.render.show_orientation_widget = false;
276
277 let result = CaptureResult {
278 value: Arc::new(Mutex::new(None)),
279 };
280 let render_result = catch_unwind(AssertUnwindSafe(|| -> Result<(), HeadlessRenderError> {
281 let mut app = build_app(
282 trajectory,
283 config,
284 export,
285 result.clone(),
286 receiver,
287 session_driver_complete,
288 )?;
289 app.run();
290 result.take().unwrap_or_else(|| {
291 Err(HeadlessRenderError::new(
292 "headless render exited without producing a capture result",
293 ))
294 })
295 }));
296
297 match render_result {
298 Ok(result) => result,
299 Err(payload) => {
300 let message = if let Some(message) = payload.downcast_ref::<String>() {
301 message.clone()
302 } else if let Some(message) = payload.downcast_ref::<&'static str>() {
303 (*message).to_string()
304 } else {
305 "headless render panicked".to_string()
306 };
307 Err(HeadlessRenderError::new(message))
308 }
309 }
310}
311
312fn build_app(
313 trajectory: Trajectory,
314 config: ViewerConfig,
315 export: HeadlessRenderConfig,
316 result: CaptureResult,
317 receiver: Option<mpsc::Receiver<ViewerCommand>>,
318 session_driver_complete: Option<Arc<AtomicBool>>,
319) -> Result<App, HeadlessRenderError> {
320 if export.width == 0 || export.height == 0 {
321 return Err(HeadlessRenderError::new(
322 "headless render image size must be non-zero",
323 ));
324 }
325
326 let mut app = App::new();
327 app.add_plugins(
328 DefaultPlugins
329 .build()
330 .disable::<bevy::log::LogPlugin>()
331 .set(ImagePlugin::default_nearest())
332 .set(AssetPlugin {
333 file_path: asset_root(),
334 ..default()
335 })
336 .set(WindowPlugin {
337 primary_window: None,
338 exit_condition: ExitCondition::DontExit,
339 ..default()
340 })
341 .disable::<WinitPlugin>(),
342 )
343 .add_plugins(PanOrbitCameraPlugin)
344 .add_plugins(ScheduleRunnerPlugin::run_loop(Duration::from_secs_f64(
345 1.0 / 60.0,
346 )));
347
348 app.insert_resource(HeadlessTargetSpec {
349 width: export.width,
350 height: export.height,
351 })
352 .insert_resource(CaptureSettings {
353 path: export.path,
354 stable_frames: export.stable_frames.max(1),
355 })
356 .insert_resource(CaptureState {
357 preroll_remaining: export.preroll_frames,
358 stable_frames: 0,
359 requested: false,
360 request_delay_remaining: 0,
361 })
362 .insert_resource(result);
363
364 if let Some(complete) = session_driver_complete {
365 app.insert_resource(SessionDriverState { complete });
366 }
367
368 configure_shared_app(
369 &mut app,
370 trajectory,
371 config,
372 receiver,
373 Arc::new(Mutex::new(super::session::ViewerSnapshot::default())),
374 );
375 app.add_systems(
376 Startup,
377 setup_headless_target.before(super::systems::setup_camera),
378 );
379 app.add_systems(
380 PostUpdate,
381 queue_capture
382 .after(PanOrbitCameraSystemSet)
383 .after(TransformSystems::Propagate),
384 );
385 setup_image_copy(&mut app);
386 app.add_systems(PostUpdate, save_capture.after(queue_capture));
387
388 Ok(app)
389}
390
391fn setup_headless_target(
392 mut commands: Commands,
393 mut images: ResMut<Assets<Image>>,
394 render_device: Res<RenderDevice>,
395 spec: Res<HeadlessTargetSpec>,
396) {
397 let size = Extent3d {
398 width: spec.width,
399 height: spec.height,
400 ..default()
401 };
402 let mut render_target_image =
403 Image::new_target_texture(size.width, size.height, TextureFormat::bevy_default(), None);
404 render_target_image.texture_descriptor.usage |= TextureUsages::COPY_SRC;
405 let render_target_image_handle = images.add(render_target_image);
406
407 let cpu_image =
408 Image::new_target_texture(size.width, size.height, TextureFormat::bevy_default(), None);
409 let cpu_image_handle = images.add(cpu_image);
410
411 commands.spawn(ImageCopier::new(
412 render_target_image_handle.clone(),
413 size,
414 &render_device,
415 ));
416 commands.spawn(ImageToSave(cpu_image_handle));
417 commands.insert_resource(MainCameraRenderTarget(RenderTarget::Image(
418 render_target_image_handle.into(),
419 )));
420}
421
422fn setup_image_copy(app: &mut App) {
423 let (sender, receiver) = crossbeam_channel::unbounded();
424 app.insert_resource(MainWorldReceiver(receiver));
425
426 let render_app = app.sub_app_mut(RenderApp);
427 let mut graph = render_app.world_mut().resource_mut::<RenderGraph>();
428 graph.add_node(ImageCopy, ImageCopyDriver);
429 graph.add_node_edge(bevy::render::graph::CameraDriverLabel, ImageCopy);
430
431 render_app
432 .insert_resource(RenderWorldSender(sender))
433 .add_systems(ExtractSchedule, image_copy_extract)
434 .add_systems(
435 Render,
436 receive_image_from_buffer.after(RenderSystems::Render),
437 );
438}
439
440fn image_copy_extract(mut commands: Commands, image_copiers: Extract<Query<&ImageCopier>>) {
441 commands.insert_resource(ImageCopiers(image_copiers.iter().cloned().collect()));
442}
443
444fn receive_image_from_buffer(
445 image_copiers: Res<ImageCopiers>,
446 render_device: Res<RenderDevice>,
447 sender: Res<RenderWorldSender>,
448) {
449 for image_copier in image_copiers.iter() {
450 let buffer_slice = image_copier.buffer.slice(..);
451 let (buffer_sender, buffer_receiver) = crossbeam_channel::bounded(1);
452
453 buffer_slice.map_async(MapMode::Read, move |result| {
454 let _ = buffer_sender.send(result);
455 });
456
457 render_device
458 .poll(PollType::wait_indefinitely())
459 .expect("failed to poll render device");
460 buffer_receiver
461 .recv()
462 .expect("failed to receive map_async result")
463 .expect("failed to map image buffer");
464
465 let _ = sender.send(buffer_slice.get_mapped_range().to_vec());
466 image_copier.buffer.unmap();
467 }
468}
469
470fn queue_capture(
471 mut state: ResMut<CaptureState>,
472 settings: Res<CaptureSettings>,
473 viewer: Res<super::ViewerState>,
474 camera: Res<super::CameraState>,
475 session_driver_state: Option<Res<SessionDriverState>>,
476) {
477 if state.requested {
478 return;
479 }
480
481 if let Some(session_driver_state) = session_driver_state
482 && !session_driver_state.complete.load(Ordering::Acquire)
483 {
484 state.stable_frames = 0;
485 return;
486 }
487
488 if viewer.needs_render || camera.needs_apply {
489 state.stable_frames = 0;
490 return;
491 }
492
493 if state.preroll_remaining > 0 {
494 state.preroll_remaining -= 1;
495 return;
496 }
497
498 state.stable_frames += 1;
499 if state.stable_frames >= settings.stable_frames {
500 state.requested = true;
501 state.request_delay_remaining = 1;
502 }
503}
504
505fn save_capture(
506 images_to_save: Query<&ImageToSave>,
507 receiver: Res<MainWorldReceiver>,
508 mut images: ResMut<Assets<Image>>,
509 settings: Res<CaptureSettings>,
510 mut state: ResMut<CaptureState>,
511 result: Res<CaptureResult>,
512 mut app_exit: MessageWriter<AppExit>,
513) {
514 if !state.requested {
515 return;
516 }
517 if state.request_delay_remaining > 0 {
518 state.request_delay_remaining -= 1;
519 return;
520 }
521
522 let mut image_data = Vec::new();
523 while let Ok(data) = receiver.try_recv() {
524 image_data = data;
525 }
526 if image_data.is_empty() {
527 result.store(Err(HeadlessRenderError::new(
528 "headless render did not receive any image data",
529 )));
530 app_exit.write(AppExit::Success);
531 return;
532 }
533
534 let save_result = save_image_data(&settings.path, images_to_save, &mut images, &image_data);
535 result.store(save_result.map_err(HeadlessRenderError::from));
536 state.requested = false;
537 app_exit.write(AppExit::Success);
538}
539
540fn save_image_data(
541 path: &Path,
542 images_to_save: Query<&ImageToSave>,
543 images: &mut Assets<Image>,
544 image_data: &[u8],
545) -> std::io::Result<()> {
546 let Some(image) = images_to_save.iter().next() else {
547 return Err(std::io::Error::other("missing CPU image target"));
548 };
549 let img_bytes = images
550 .get_mut(image.id())
551 .ok_or_else(|| std::io::Error::other("missing saved image"))?;
552 let row_bytes =
553 img_bytes.width() as usize * img_bytes.texture_descriptor.format.pixel_size().unwrap();
554 let aligned_row_bytes = RenderDevice::align_copy_bytes_per_row(row_bytes);
555 if row_bytes == aligned_row_bytes {
556 img_bytes
557 .data
558 .as_mut()
559 .unwrap()
560 .clone_from_slice(image_data);
561 } else {
562 img_bytes.data = Some(
563 image_data
564 .chunks(aligned_row_bytes)
565 .take(img_bytes.height() as usize)
566 .flat_map(|row| &row[..row_bytes.min(row.len())])
567 .copied()
568 .collect(),
569 );
570 }
571
572 if let Some(parent) = path.parent()
573 && !parent.as_os_str().is_empty()
574 {
575 std::fs::create_dir_all(parent)?;
576 }
577
578 let image = img_bytes
579 .clone()
580 .try_into_dynamic()
581 .map_err(|err| std::io::Error::other(err.to_string()))?;
582 image
583 .to_rgba8()
584 .save(path)
585 .map_err(|err| std::io::Error::other(err.to_string()))
586}
587
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{BondList, Face, FaceList};
    use image::GenericImageView;
    use std::sync::{
        Mutex, MutexGuard, PoisonError,
        atomic::{AtomicU64, Ordering},
    };

    /// Source of unique suffixes for temporary output files.
    static UNIQUE_ID: AtomicU64 = AtomicU64::new(0);
    /// Serializes the headless render tests.
    static HEADLESS_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires the test lock, recovering it when an earlier test panicked
    /// while holding it. The lock guards no data, so a poisoned guard is safe
    /// to reuse, and one failing test no longer fails the others with an
    /// unrelated `PoisonError`.
    fn headless_test_guard() -> MutexGuard<'static, ()> {
        HEADLESS_TEST_LOCK
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    /// Four-atom structure in an 8 x 8 x 8 non-periodic cell.
    fn fixture_structure() -> Structure {
        Structure::new(
            vec![
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.74],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
            ],
            vec![8, 1, 1, 1],
            [[8.0, 0.0, 0.0], [0.0, 8.0, 0.0], [0.0, 0.0, 8.0]],
            [false, false, false],
        )
    }

    /// Returns a fresh PNG path in the system temp directory.
    fn temp_png(name: &str) -> PathBuf {
        let id = UNIQUE_ID.fetch_add(1, Ordering::Relaxed);
        std::env::temp_dir().join(format!("ak-headless-{name}-{id}.png"))
    }

    /// 320 x 240 render config; uses longer warm-up counts when `CI` is set.
    fn test_render_config(path: &Path) -> HeadlessRenderConfig {
        let mut config = HeadlessRenderConfig::new(path, 320, 240);
        if std::env::var_os("CI").is_some() {
            config.preroll_frames = 48;
            config.stable_frames = 12;
        }
        config
    }

    /// Render tests run locally by default; under `CI` they run only when
    /// `ATOMIC_KERNELS_RUN_RUST_HEADLESS_TESTS` is set.
    fn should_run_headless_render_tests() -> bool {
        if std::env::var_os("ATOMIC_KERNELS_RUN_RUST_HEADLESS_TESTS").is_some() {
            return true;
        }
        std::env::var_os("CI").is_none()
    }

    /// Checks the saved image has the expected size and more than one
    /// distinct pixel value.
    fn image_has_content(path: &Path, width: u32, height: u32) -> bool {
        let image = image::open(path).expect("saved image should be readable");
        assert_eq!(image.dimensions(), (width, height));
        let rgba = image.to_rgba8();
        let mut unique = std::collections::BTreeSet::new();
        for pixel in rgba.pixels() {
            unique.insert(pixel.0);
            if unique.len() > 8 {
                break;
            }
        }
        unique.len() > 1
    }

    /// Runs `render_once` until it produces a non-flat image (up to five
    /// attempts under `CI`, one otherwise). A "no GPU" error counts as a skip.
    fn assert_render_succeeds<F>(name: &str, mut render_once: F)
    where
        F: FnMut(&Path) -> Result<(), HeadlessRenderError>,
    {
        let attempts = if std::env::var_os("CI").is_some() {
            5
        } else {
            1
        };
        for attempt in 0..attempts {
            let path = temp_png(name);
            match render_once(&path) {
                Ok(()) => {
                    if image_has_content(&path, 320, 240) {
                        let _ = std::fs::remove_file(path);
                        return;
                    }
                }
                Err(err) if err.to_string().contains("Unable to find a GPU") => return,
                Err(err) => panic!("headless export should succeed: {err}"),
            }
            let _ = std::fs::remove_file(&path);
            if attempt + 1 == attempts {
                panic!("rendered image should not be a flat color");
            }
        }
    }

    #[test]
    fn exports_default_scene_to_png() {
        if !should_run_headless_render_tests() {
            return;
        }
        let _guard = headless_test_guard();
        let config = ViewerConfig::default();
        assert_render_succeeds("default", |path| {
            export_structure_image(
                fixture_structure(),
                config.clone(),
                test_render_config(path),
            )
        });
    }

    #[test]
    fn exports_scripted_scene_with_shared_session_commands() {
        if !should_run_headless_render_tests() {
            return;
        }
        let _guard = headless_test_guard();
        let mut config = ViewerConfig::default();
        config.render.show_axes = false;
        assert_render_succeeds("scripted", |path| {
            export_image_with_session(
                Trajectory::new(vec![fixture_structure()]),
                config.clone(),
                test_render_config(path),
                |session| {
                    session
                        .set_bonds(BondList::new([(0, 1), (0, 2), (0, 3)]), Some(0))
                        .unwrap();
                    session
                        .set_faces(
                            FaceList::new([Face::new([1, 2, 3], [0.2, 0.6, 0.9, 0.35]).unwrap()]),
                            Some(0),
                        )
                        .unwrap();
                    session
                        .set_render_style(
                            super::super::RenderStyle::BallAndStick(
                                super::super::BallAndStickStyle {
                                    atom_scale: 0.5,
                                    bond_radius: 0.06,
                                    bond_color: [0.5, 0.5, 0.5, 1.0],
                                    bond_scope: super::super::BondScope::TouchSelection,
                                },
                            ),
                            vec![true, true, true, true],
                            Some(0),
                            false,
                        )
                        .unwrap();
                    session.frame_all().unwrap();
                },
            )
        });
    }
}