ak_vis/viewer/
headless.rs

1use super::runtime::{MainCameraRenderTarget, asset_root, configure_shared_app};
2use super::{ViewerCommand, ViewerConfig, ViewerSessionHandle};
3use ak_core::{Structure, Trajectory};
4use bevy::app::{AppExit, ScheduleRunnerPlugin};
5use bevy::camera::RenderTarget;
6use bevy::image::TextureFormatPixelInfo;
7use bevy::prelude::*;
8use bevy::render::render_asset::RenderAssets;
9use bevy::render::render_graph::{self, NodeRunError, RenderGraph, RenderGraphContext};
10use bevy::render::render_resource::{
11    Buffer, BufferDescriptor, BufferUsages, CommandEncoderDescriptor, Extent3d, MapMode, PollType,
12    TexelCopyBufferInfo, TexelCopyBufferLayout, TextureFormat, TextureUsages,
13};
14use bevy::render::renderer::{RenderContext, RenderDevice, RenderQueue};
15use bevy::render::{Extract, ExtractSchedule, Render, RenderApp, RenderSystems};
16use bevy::transform::TransformSystems;
17use bevy::window::ExitCondition;
18use bevy::winit::WinitPlugin;
19use bevy_panorbit_camera::{PanOrbitCameraPlugin, PanOrbitCameraSystemSet};
20use crossbeam_channel::{Receiver, Sender};
21use std::panic::{AssertUnwindSafe, catch_unwind};
22use std::path::{Path, PathBuf};
23use std::sync::{
24    Arc, Mutex,
25    atomic::{AtomicBool, Ordering},
26    mpsc,
27};
28use std::thread;
29use std::time::Duration;
30
/// Default number of settled frames skipped before stability counting starts.
const DEFAULT_PREROLL_FRAMES: u32 = 4;
/// Default number of consecutive settled frames required before capturing.
const DEFAULT_STABLE_FRAMES: u32 = 2;

/// Parameters describing one headless image export.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HeadlessRenderConfig {
    /// Destination file; missing parent directories are created when saving.
    pub path: PathBuf,
    /// Render-target width in pixels (must be non-zero).
    pub width: u32,
    /// Render-target height in pixels (must be non-zero).
    pub height: u32,
    /// Settled frames to skip before stable frames start being counted.
    pub preroll_frames: u32,
    /// Consecutive settled frames required before the capture is taken (0 is treated as 1).
    pub stable_frames: u32,
}

impl HeadlessRenderConfig {
    /// Creates a config that renders a `width` x `height` image to `path`, using the default
    /// preroll and stability frame counts.
    pub fn new(path: impl Into<PathBuf>, width: u32, height: u32) -> Self {
        let path: PathBuf = path.into();
        let (preroll_frames, stable_frames) = (DEFAULT_PREROLL_FRAMES, DEFAULT_STABLE_FRAMES);
        Self {
            path,
            width,
            height,
            preroll_frames,
            stable_frames,
        }
    }
}
54
/// Error returned by the headless export entry points; carries a human-readable message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HeadlessRenderError {
    message: String,
}

impl HeadlessRenderError {
    /// Wraps `message` as an export error.
    fn new(message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self { message }
    }
}

impl std::fmt::Display for HeadlessRenderError {
    /// Prints the stored message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.message)
    }
}

impl std::error::Error for HeadlessRenderError {}

impl From<std::io::Error> for HeadlessRenderError {
    /// Converts an I/O failure into an export error using the I/O error's display text.
    fn from(value: std::io::Error) -> Self {
        let text = value.to_string();
        Self::new(text)
    }
}
81
/// Where the capture is written and how long the scene must stay settled; inserted by
/// `build_app` from the export config.
#[derive(Resource, Clone)]
struct CaptureSettings {
    /// Output file handed to `save_image_data`.
    path: PathBuf,
    /// Consecutive settled frames `queue_capture` waits for before requesting a capture
    /// (`build_app` clamps this to at least 1).
    stable_frames: u32,
}

/// Pixel size of the off-screen render target created by `setup_headless_target`.
#[derive(Resource, Clone, Copy)]
struct HeadlessTargetSpec {
    width: u32,
    height: u32,
}

/// Progress of the capture state machine shared by `queue_capture` and `save_capture`.
#[derive(Resource, Default)]
struct CaptureState {
    /// Settled frames still to skip before stability counting begins.
    preroll_remaining: u32,
    /// Consecutive settled frames seen so far; reset whenever the scene has pending work.
    stable_frames: u32,
    /// Set once a capture has been queued; `save_capture` only acts while this is true.
    requested: bool,
    /// Frames `save_capture` still waits after the request before reading back pixels.
    request_delay_remaining: u32,
}
101
102#[derive(Resource, Clone)]
103struct CaptureResult {
104    value: Arc<Mutex<Option<Result<(), HeadlessRenderError>>>>,
105}
106
107impl CaptureResult {
108    fn store(&self, result: Result<(), HeadlessRenderError>) {
109        if let Ok(mut slot) = self.value.lock() {
110            *slot = Some(result);
111        }
112    }
113
114    fn take(&self) -> Option<Result<(), HeadlessRenderError>> {
115        self.value.lock().ok().and_then(|mut slot| slot.take())
116    }
117}
118
/// Main-world end of the channel carrying read-back frame bytes from the render world.
#[derive(Resource, Deref)]
struct MainWorldReceiver(Receiver<Vec<u8>>);

/// Inserted only for session-driven exports. `queue_capture` does not capture while
/// `complete` is false; the session driver thread sets it when it finishes.
#[derive(Resource, Clone)]
struct SessionDriverState {
    complete: Arc<AtomicBool>,
}

/// Render-world end of the frame channel, written by `receive_image_from_buffer`.
#[derive(Resource, Deref)]
struct RenderWorldSender(Sender<Vec<u8>>);

/// Marks the CPU-side image whose pixel data is replaced with the read-back frame before saving.
#[derive(Component, Deref, DerefMut)]
struct ImageToSave(Handle<Image>);

/// Render-world snapshot of every `ImageCopier`, re-inserted each frame by `image_copy_extract`.
#[derive(Clone, Default, Resource, Deref, DerefMut)]
struct ImageCopiers(pub Vec<ImageCopier>);

/// Pairs a GPU render-target image with the mappable buffer it is copied into for readback.
#[derive(Clone, Component)]
struct ImageCopier {
    /// `MAP_READ | COPY_DST` buffer with rows padded via `align_copy_bytes_per_row`.
    buffer: Buffer,
    /// Render-target image that `ImageCopyDriver` copies into `buffer`.
    src_image: Handle<Image>,
}

/// Render-graph label of the readback copy node.
#[derive(bevy::render::render_graph::RenderLabel, Debug, PartialEq, Eq, Clone, Hash)]
struct ImageCopy;

/// Render-graph node that copies each copier's source texture into its buffer.
#[derive(Default)]
struct ImageCopyDriver;
147
148impl ImageCopier {
149    fn new(src_image: Handle<Image>, size: Extent3d, render_device: &RenderDevice) -> Self {
150        let padded_bytes_per_row = RenderDevice::align_copy_bytes_per_row(size.width as usize * 4);
151        let buffer = render_device.create_buffer(&BufferDescriptor {
152            label: None,
153            size: padded_bytes_per_row as u64 * size.height as u64,
154            usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
155            mapped_at_creation: false,
156        });
157        Self { buffer, src_image }
158    }
159}
160
161impl render_graph::Node for ImageCopyDriver {
162    fn run(
163        &self,
164        _graph: &mut RenderGraphContext,
165        render_context: &mut RenderContext,
166        world: &World,
167    ) -> Result<(), NodeRunError> {
168        let image_copiers = world
169            .get_resource::<ImageCopiers>()
170            .expect("image copiers should be extracted");
171        let gpu_images = world
172            .get_resource::<RenderAssets<bevy::render::texture::GpuImage>>()
173            .expect("gpu images should exist");
174
175        for image_copier in image_copiers.iter() {
176            let src_image = gpu_images
177                .get(&image_copier.src_image)
178                .expect("render target image should exist");
179            let mut encoder = render_context
180                .render_device()
181                .create_command_encoder(&CommandEncoderDescriptor::default());
182
183            let block_dimensions = src_image.texture_format.block_dimensions();
184            let block_size = src_image.texture_format.block_copy_size(None).unwrap();
185            let padded_bytes_per_row = RenderDevice::align_copy_bytes_per_row(
186                (src_image.size.width as usize / block_dimensions.0 as usize) * block_size as usize,
187            );
188
189            encoder.copy_texture_to_buffer(
190                src_image.texture.as_image_copy(),
191                TexelCopyBufferInfo {
192                    buffer: &image_copier.buffer,
193                    layout: TexelCopyBufferLayout {
194                        offset: 0,
195                        bytes_per_row: Some(padded_bytes_per_row as u32),
196                        rows_per_image: None,
197                    },
198                },
199                src_image.size,
200            );
201
202            let render_queue = world.get_resource::<RenderQueue>().unwrap();
203            render_queue.submit(std::iter::once(encoder.finish()));
204        }
205
206        Ok(())
207    }
208}
209
210pub fn export_image(
211    trajectory: Trajectory,
212    config: ViewerConfig,
213    export: HeadlessRenderConfig,
214) -> Result<(), HeadlessRenderError> {
215    run_headless(trajectory, config, export, None, None)
216}
217
218pub fn export_structure_image(
219    structure: Structure,
220    config: ViewerConfig,
221    export: HeadlessRenderConfig,
222) -> Result<(), HeadlessRenderError> {
223    export_image(Trajectory::new(vec![structure]), config, export)
224}
225
226pub fn export_image_with_session<F>(
227    trajectory: Trajectory,
228    config: ViewerConfig,
229    export: HeadlessRenderConfig,
230    driver: F,
231) -> Result<(), HeadlessRenderError>
232where
233    F: FnOnce(ViewerSessionHandle) + Send + 'static,
234{
235    let (sender, receiver) = mpsc::channel();
236    let handle = ViewerSessionHandle::new(sender);
237    let driver_complete = Arc::new(AtomicBool::new(false));
238    let driver_complete_thread = driver_complete.clone();
239    let driver_thread = thread::Builder::new()
240        .name("ak-headless-session-driver".to_string())
241        .spawn(move || {
242            driver(handle);
243            driver_complete_thread.store(true, Ordering::Release);
244        })
245        .map_err(|err| HeadlessRenderError::new(format!("failed to spawn driver: {err}")))?;
246
247    let result = run_headless(
248        trajectory,
249        config,
250        export,
251        Some(receiver),
252        Some(driver_complete),
253    );
254    let _ = driver_thread.join();
255    result
256}
257
258pub fn export_prepared_image(
259    trajectory: Trajectory,
260    config: ViewerConfig,
261    export: HeadlessRenderConfig,
262    receiver: mpsc::Receiver<ViewerCommand>,
263) -> Result<(), HeadlessRenderError> {
264    run_headless(trajectory, config, export, Some(receiver), None)
265}
266
267fn run_headless(
268    trajectory: Trajectory,
269    mut config: ViewerConfig,
270    export: HeadlessRenderConfig,
271    receiver: Option<mpsc::Receiver<ViewerCommand>>,
272    session_driver_complete: Option<Arc<AtomicBool>>,
273) -> Result<(), HeadlessRenderError> {
274    config.render.show_ui = false;
275    config.render.show_orientation_widget = false;
276
277    let result = CaptureResult {
278        value: Arc::new(Mutex::new(None)),
279    };
280    let render_result = catch_unwind(AssertUnwindSafe(|| -> Result<(), HeadlessRenderError> {
281        let mut app = build_app(
282            trajectory,
283            config,
284            export,
285            result.clone(),
286            receiver,
287            session_driver_complete,
288        )?;
289        app.run();
290        result.take().unwrap_or_else(|| {
291            Err(HeadlessRenderError::new(
292                "headless render exited without producing a capture result",
293            ))
294        })
295    }));
296
297    match render_result {
298        Ok(result) => result,
299        Err(payload) => {
300            let message = if let Some(message) = payload.downcast_ref::<String>() {
301                message.clone()
302            } else if let Some(message) = payload.downcast_ref::<&'static str>() {
303                (*message).to_string()
304            } else {
305                "headless render panicked".to_string()
306            };
307            Err(HeadlessRenderError::new(message))
308        }
309    }
310}
311
/// Builds the headless Bevy [`App`] that renders `trajectory` into an off-screen image, saves
/// it, and exits.
///
/// The app uses `DefaultPlugins` without a primary window or the winit plugin. Updates are
/// driven by a `ScheduleRunnerPlugin` loop at roughly 60 Hz. The function also registers the
/// target-setup, capture-queueing, readback, and save systems defined in this module.
///
/// # Errors
///
/// Returns an error when `export.width` or `export.height` is zero.
fn build_app(
    trajectory: Trajectory,
    config: ViewerConfig,
    export: HeadlessRenderConfig,
    result: CaptureResult,
    receiver: Option<mpsc::Receiver<ViewerCommand>>,
    session_driver_complete: Option<Arc<AtomicBool>>,
) -> Result<App, HeadlessRenderError> {
    if export.width == 0 || export.height == 0 {
        return Err(HeadlessRenderError::new(
            "headless render image size must be non-zero",
        ));
    }

    let mut app = App::new();
    app.add_plugins(
        DefaultPlugins
            .build()
            // NOTE(review): presumably disabled so repeated exports in one process do not try
            // to install a second global logger — confirm.
            .disable::<bevy::log::LogPlugin>()
            .set(ImagePlugin::default_nearest())
            .set(AssetPlugin {
                file_path: asset_root(),
                ..default()
            })
            .set(WindowPlugin {
                // No window, and no automatic exit for lack of windows: the app exits itself
                // by writing `AppExit` from `save_capture`.
                primary_window: None,
                exit_condition: ExitCondition::DontExit,
                ..default()
            })
            .disable::<WinitPlugin>(),
    )
    .add_plugins(PanOrbitCameraPlugin)
    // With winit disabled, the schedule runner drives updates (~60 per second).
    .add_plugins(ScheduleRunnerPlugin::run_loop(Duration::from_secs_f64(
        1.0 / 60.0,
    )));

    app.insert_resource(HeadlessTargetSpec {
        width: export.width,
        height: export.height,
    })
    .insert_resource(CaptureSettings {
        path: export.path,
        // At least one settled frame is always required before capturing.
        stable_frames: export.stable_frames.max(1),
    })
    .insert_resource(CaptureState {
        preroll_remaining: export.preroll_frames,
        stable_frames: 0,
        requested: false,
        request_delay_remaining: 0,
    })
    .insert_resource(result);

    // Only session-driven exports provide a completion flag; `queue_capture` waits on it.
    if let Some(complete) = session_driver_complete {
        app.insert_resource(SessionDriverState { complete });
    }

    configure_shared_app(
        &mut app,
        trajectory,
        config,
        receiver,
        Arc::new(Mutex::new(super::session::ViewerSnapshot::default())),
    );
    // The off-screen target is created before `setup_camera` runs. Presumably `setup_camera`
    // reads `MainCameraRenderTarget` — confirm.
    app.add_systems(
        Startup,
        setup_headless_target.before(super::systems::setup_camera),
    );
    // Decide on capture only after camera control and transform propagation have run, so the
    // check sees this frame's final state.
    app.add_systems(
        PostUpdate,
        queue_capture
            .after(PanOrbitCameraSystemSet)
            .after(TransformSystems::Propagate),
    );
    setup_image_copy(&mut app);
    app.add_systems(PostUpdate, save_capture.after(queue_capture));

    Ok(app)
}
390
391fn setup_headless_target(
392    mut commands: Commands,
393    mut images: ResMut<Assets<Image>>,
394    render_device: Res<RenderDevice>,
395    spec: Res<HeadlessTargetSpec>,
396) {
397    let size = Extent3d {
398        width: spec.width,
399        height: spec.height,
400        ..default()
401    };
402    let mut render_target_image =
403        Image::new_target_texture(size.width, size.height, TextureFormat::bevy_default(), None);
404    render_target_image.texture_descriptor.usage |= TextureUsages::COPY_SRC;
405    let render_target_image_handle = images.add(render_target_image);
406
407    let cpu_image =
408        Image::new_target_texture(size.width, size.height, TextureFormat::bevy_default(), None);
409    let cpu_image_handle = images.add(cpu_image);
410
411    commands.spawn(ImageCopier::new(
412        render_target_image_handle.clone(),
413        size,
414        &render_device,
415    ));
416    commands.spawn(ImageToSave(cpu_image_handle));
417    commands.insert_resource(MainCameraRenderTarget(RenderTarget::Image(
418        render_target_image_handle.into(),
419    )));
420}
421
422fn setup_image_copy(app: &mut App) {
423    let (sender, receiver) = crossbeam_channel::unbounded();
424    app.insert_resource(MainWorldReceiver(receiver));
425
426    let render_app = app.sub_app_mut(RenderApp);
427    let mut graph = render_app.world_mut().resource_mut::<RenderGraph>();
428    graph.add_node(ImageCopy, ImageCopyDriver);
429    graph.add_node_edge(bevy::render::graph::CameraDriverLabel, ImageCopy);
430
431    render_app
432        .insert_resource(RenderWorldSender(sender))
433        .add_systems(ExtractSchedule, image_copy_extract)
434        .add_systems(
435            Render,
436            receive_image_from_buffer.after(RenderSystems::Render),
437        );
438}
439
440fn image_copy_extract(mut commands: Commands, image_copiers: Extract<Query<&ImageCopier>>) {
441    commands.insert_resource(ImageCopiers(image_copiers.iter().cloned().collect()));
442}
443
444fn receive_image_from_buffer(
445    image_copiers: Res<ImageCopiers>,
446    render_device: Res<RenderDevice>,
447    sender: Res<RenderWorldSender>,
448) {
449    for image_copier in image_copiers.iter() {
450        let buffer_slice = image_copier.buffer.slice(..);
451        let (buffer_sender, buffer_receiver) = crossbeam_channel::bounded(1);
452
453        buffer_slice.map_async(MapMode::Read, move |result| {
454            let _ = buffer_sender.send(result);
455        });
456
457        render_device
458            .poll(PollType::wait_indefinitely())
459            .expect("failed to poll render device");
460        buffer_receiver
461            .recv()
462            .expect("failed to receive map_async result")
463            .expect("failed to map image buffer");
464
465        let _ = sender.send(buffer_slice.get_mapped_range().to_vec());
466        image_copier.buffer.unmap();
467    }
468}
469
470fn queue_capture(
471    mut state: ResMut<CaptureState>,
472    settings: Res<CaptureSettings>,
473    viewer: Res<super::ViewerState>,
474    camera: Res<super::CameraState>,
475    session_driver_state: Option<Res<SessionDriverState>>,
476) {
477    if state.requested {
478        return;
479    }
480
481    if let Some(session_driver_state) = session_driver_state
482        && !session_driver_state.complete.load(Ordering::Acquire)
483    {
484        state.stable_frames = 0;
485        return;
486    }
487
488    if viewer.needs_render || camera.needs_apply {
489        state.stable_frames = 0;
490        return;
491    }
492
493    if state.preroll_remaining > 0 {
494        state.preroll_remaining -= 1;
495        return;
496    }
497
498    state.stable_frames += 1;
499    if state.stable_frames >= settings.stable_frames {
500        state.requested = true;
501        state.request_delay_remaining = 1;
502    }
503}
504
505fn save_capture(
506    images_to_save: Query<&ImageToSave>,
507    receiver: Res<MainWorldReceiver>,
508    mut images: ResMut<Assets<Image>>,
509    settings: Res<CaptureSettings>,
510    mut state: ResMut<CaptureState>,
511    result: Res<CaptureResult>,
512    mut app_exit: MessageWriter<AppExit>,
513) {
514    if !state.requested {
515        return;
516    }
517    if state.request_delay_remaining > 0 {
518        state.request_delay_remaining -= 1;
519        return;
520    }
521
522    let mut image_data = Vec::new();
523    while let Ok(data) = receiver.try_recv() {
524        image_data = data;
525    }
526    if image_data.is_empty() {
527        result.store(Err(HeadlessRenderError::new(
528            "headless render did not receive any image data",
529        )));
530        app_exit.write(AppExit::Success);
531        return;
532    }
533
534    let save_result = save_image_data(&settings.path, images_to_save, &mut images, &image_data);
535    result.store(save_result.map_err(HeadlessRenderError::from));
536    state.requested = false;
537    app_exit.write(AppExit::Success);
538}
539
540fn save_image_data(
541    path: &Path,
542    images_to_save: Query<&ImageToSave>,
543    images: &mut Assets<Image>,
544    image_data: &[u8],
545) -> std::io::Result<()> {
546    let Some(image) = images_to_save.iter().next() else {
547        return Err(std::io::Error::other("missing CPU image target"));
548    };
549    let img_bytes = images
550        .get_mut(image.id())
551        .ok_or_else(|| std::io::Error::other("missing saved image"))?;
552    let row_bytes =
553        img_bytes.width() as usize * img_bytes.texture_descriptor.format.pixel_size().unwrap();
554    let aligned_row_bytes = RenderDevice::align_copy_bytes_per_row(row_bytes);
555    if row_bytes == aligned_row_bytes {
556        img_bytes
557            .data
558            .as_mut()
559            .unwrap()
560            .clone_from_slice(image_data);
561    } else {
562        img_bytes.data = Some(
563            image_data
564                .chunks(aligned_row_bytes)
565                .take(img_bytes.height() as usize)
566                .flat_map(|row| &row[..row_bytes.min(row.len())])
567                .copied()
568                .collect(),
569        );
570    }
571
572    if let Some(parent) = path.parent()
573        && !parent.as_os_str().is_empty()
574    {
575        std::fs::create_dir_all(parent)?;
576    }
577
578    let image = img_bytes
579        .clone()
580        .try_into_dynamic()
581        .map_err(|err| std::io::Error::other(err.to_string()))?;
582    image
583        .to_rgba8()
584        .save(path)
585        .map_err(|err| std::io::Error::other(err.to_string()))
586}
587
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{BondList, Face, FaceList};
    use image::GenericImageView;
    use std::sync::{
        Mutex, MutexGuard,
        atomic::{AtomicU64, Ordering},
    };

    static UNIQUE_ID: AtomicU64 = AtomicU64::new(0);
    /// Serialises headless renders across tests; each test builds and runs its own Bevy app.
    static HEADLESS_TEST_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`HEADLESS_TEST_LOCK`], recovering from poisoning.
    ///
    /// The lock guards no data, so a test that panicked while holding it leaves nothing
    /// inconsistent. Unwrapping the `PoisonError` would make every later headless test fail
    /// on the lock instead of on its own assertions.
    fn lock_headless_tests() -> MutexGuard<'static, ()> {
        HEADLESS_TEST_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Small four-atom fixture (atomic numbers 8, 1, 1, 1) in an 8x8x8 cell.
    fn fixture_structure() -> Structure {
        Structure::new(
            vec![
                [0.0, 0.0, 0.0],
                [0.0, 0.0, 0.74],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
            ],
            vec![8, 1, 1, 1],
            [[8.0, 0.0, 0.0], [0.0, 8.0, 0.0], [0.0, 0.0, 8.0]],
            [false, false, false],
        )
    }

    /// Returns a fresh temp-dir PNG path.
    ///
    /// The process id keeps concurrently running test binaries (e.g. separate `cargo test`
    /// invocations) from writing to the same file; the counter keeps calls within one
    /// process apart.
    fn temp_png(name: &str) -> PathBuf {
        let id = UNIQUE_ID.fetch_add(1, Ordering::Relaxed);
        let pid = std::process::id();
        std::env::temp_dir().join(format!("ak-headless-{name}-{pid}-{id}.png"))
    }

    /// 320x240 export config; on CI, longer preroll and stability windows are used.
    fn test_render_config(path: &Path) -> HeadlessRenderConfig {
        let mut config = HeadlessRenderConfig::new(path, 320, 240);
        if std::env::var_os("CI").is_some() {
            config.preroll_frames = 48;
            config.stable_frames = 12;
        }
        config
    }

    /// Headless render tests run locally by default, and on CI only when explicitly enabled.
    fn should_run_headless_render_tests() -> bool {
        if std::env::var_os("ATOMIC_KERNELS_RUN_RUST_HEADLESS_TESTS").is_some() {
            return true;
        }
        std::env::var_os("CI").is_none()
    }

    /// Asserts the saved image has the expected size and reports whether it contains more
    /// than one distinct colour.
    fn image_has_content(path: &Path, width: u32, height: u32) -> bool {
        let image = image::open(path).expect("saved image should be readable");
        assert_eq!(image.dimensions(), (width, height));
        let rgba = image.to_rgba8();
        let mut unique = std::collections::BTreeSet::new();
        for pixel in rgba.pixels() {
            unique.insert(pixel.0);
            if unique.len() > 8 {
                break;
            }
        }
        unique.len() > 1
    }

    /// Runs `render_once` until it produces a non-flat image. Flat images are retried on CI.
    /// A missing GPU is treated as a skip.
    fn assert_render_succeeds<F>(name: &str, mut render_once: F)
    where
        F: FnMut(&Path) -> Result<(), HeadlessRenderError>,
    {
        let attempts = if std::env::var_os("CI").is_some() {
            5
        } else {
            1
        };
        for attempt in 0..attempts {
            let path = temp_png(name);
            match render_once(&path) {
                Ok(()) => {
                    if image_has_content(&path, 320, 240) {
                        let _ = std::fs::remove_file(path);
                        return;
                    }
                }
                Err(err) if err.to_string().contains("Unable to find a GPU") => return,
                Err(err) => panic!("headless export should succeed: {err}"),
            }
            let _ = std::fs::remove_file(&path);
            if attempt + 1 == attempts {
                panic!("rendered image should not be a flat color");
            }
        }
    }

    #[test]
    fn exports_default_scene_to_png() {
        if !should_run_headless_render_tests() {
            return;
        }
        let _guard = lock_headless_tests();
        let config = ViewerConfig::default();
        assert_render_succeeds("default", |path| {
            export_structure_image(
                fixture_structure(),
                config.clone(),
                test_render_config(path),
            )
        });
    }

    #[test]
    fn exports_scripted_scene_with_shared_session_commands() {
        if !should_run_headless_render_tests() {
            return;
        }
        let _guard = lock_headless_tests();
        let mut config = ViewerConfig::default();
        config.render.show_axes = false;
        assert_render_succeeds("scripted", |path| {
            export_image_with_session(
                Trajectory::new(vec![fixture_structure()]),
                config.clone(),
                test_render_config(path),
                |session| {
                    session
                        .set_bonds(BondList::new([(0, 1), (0, 2), (0, 3)]), Some(0))
                        .unwrap();
                    session
                        .set_faces(
                            FaceList::new([Face::new([1, 2, 3], [0.2, 0.6, 0.9, 0.35]).unwrap()]),
                            Some(0),
                        )
                        .unwrap();
                    session
                        .set_render_style(
                            super::super::RenderStyle::BallAndStick(
                                super::super::BallAndStickStyle {
                                    atom_scale: 0.5,
                                    bond_radius: 0.06,
                                    bond_color: [0.5, 0.5, 0.5, 1.0],
                                    bond_scope: super::super::BondScope::TouchSelection,
                                },
                            ),
                            vec![true, true, true, true],
                            Some(0),
                            false,
                        )
                        .unwrap();
                    session.frame_all().unwrap();
                },
            )
        });
    }
}