From 4cb20dc043cf70b8a1b4846c86599cc1ff9680d9 Mon Sep 17 00:00:00 2001
From: Tobias Burnus <tburnus@baylibre.com>
Date: Tue, 24 Sep 2024 17:41:39 +0200
Subject: [PATCH] libgomp: with USM, init 'link' variables with host address

If requires unified_shared_memory or self_maps is set, make
'declare target link' variables to point initially to the host pointer.

libgomp/ChangeLog:

	* target.c (gomp_load_image_to_device): For requires
	unified_shared_memory, update 'link' vars to point to the host var.
	* testsuite/libgomp.c-c++-common/target-link-3.c: New test.
	* testsuite/libgomp.c-c++-common/target-link-4.c: New test.
---
 libgomp/target.c                              |  6 +++
 .../libgomp.c-c++-common/target-link-3.c      | 52 +++++++++++++++++++
 .../libgomp.c-c++-common/target-link-4.c      | 52 +++++++++++++++++++
 3 files changed, 110 insertions(+)
 create mode 100644 libgomp/testsuite/libgomp.c-c++-common/target-link-3.c
 create mode 100644 libgomp/testsuite/libgomp.c-c++-common/target-link-4.c

diff --git a/libgomp/target.c b/libgomp/target.c
index 6918694a843b..cf62af61f3b6 100644
--- a/libgomp/target.c
+++ b/libgomp/target.c
@@ -2454,6 +2454,12 @@ gomp_load_image_to_device (struct gomp_device_descr *devicep, unsigned version,
       array->right = NULL;
       splay_tree_insert (&devicep->mem_map, array);
       array++;
+
+      if (is_link_var
+	  && (omp_requires_mask
+	      & (GOMP_REQUIRES_UNIFIED_SHARED_MEMORY | GOMP_REQUIRES_SELF_MAPS)))
+	gomp_copy_host2dev (devicep, NULL, (void *) target_var->start,
+			    &k->host_start, sizeof (void *), false, NULL);
     }
 
   /* Last entry is for the ICV struct variable; if absent, start = end = 0.  */
diff --git a/libgomp/testsuite/libgomp.c-c++-common/target-link-3.c b/libgomp/testsuite/libgomp.c-c++-common/target-link-3.c
new file mode 100644
index 000000000000..c707b38b7d46
--- /dev/null
+++ b/libgomp/testsuite/libgomp.c-c++-common/target-link-3.c
@@ -0,0 +1,52 @@
+/* { dg-do run }  */
+
+#include <stdint.h>
+#include <omp.h>
+
+#pragma omp requires unified_shared_memory
+
+int A[3] = {-3,-4,-5};
+static int q = -401;
+#pragma omp declare target link(A, q)
+
+#pragma omp begin declare target
+void
+f (uintptr_t *pA, uintptr_t *pq)
+{
+  if (A[0] != 1 || A[1] != 2 || A[2] != 3 || q != 42)
+    __builtin_abort ();
+  A[0] = 13;
+  A[1] = 14;
+  A[2] = 15;
+  q = 23;
+  *pA = (uintptr_t) &A[0];
+  *pq = (uintptr_t) &q;
+}
+#pragma omp end declare target
+
+int
+main ()
+{
+  uintptr_t hpA = (uintptr_t) &A[0];
+  uintptr_t hpq = (uintptr_t) &q;
+  uintptr_t dpA, dpq;
+
+  A[0] = 1;
+  A[1] = 2;
+  A[2] = 3;
+  q = 42;
+
+  for (int i = 0; i <= omp_get_num_devices (); ++i)
+    {
+      #pragma omp target device(device_num: i) map(dpA, dpq)
+	f (&dpA, &dpq);
+      if (hpA != dpA || hpq != dpq)
+	__builtin_abort ();
+      if (A[0] != 13 || A[1] != 14 || A[2] != 15 || q != 23)
+	__builtin_abort ();
+      A[0] = 1;
+      A[1] = 2;
+      A[2] = 3;
+      q = 42;
+    }
+}
diff --git a/libgomp/testsuite/libgomp.c-c++-common/target-link-4.c b/libgomp/testsuite/libgomp.c-c++-common/target-link-4.c
new file mode 100644
index 000000000000..785055e216d7
--- /dev/null
+++ b/libgomp/testsuite/libgomp.c-c++-common/target-link-4.c
@@ -0,0 +1,52 @@
+/* { dg-do run }  */
+
+#include <stdint.h>
+#include <omp.h>
+
+#pragma omp requires self_maps
+
+int A[3] = {-3,-4,-5};
+static int q = -401;
+#pragma omp declare target link(A, q)
+
+#pragma omp begin declare target
+void
+f (uintptr_t *pA, uintptr_t *pq)
+{
+  if (A[0] != 1 || A[1] != 2 || A[2] != 3 || q != 42)
+    __builtin_abort ();
+  A[0] = 13;
+  A[1] = 14;
+  A[2] = 15;
+  q = 23;
+  *pA = (uintptr_t) &A[0];
+  *pq = (uintptr_t) &q;
+}
+#pragma omp end declare target
+
+int
+main ()
+{
+  uintptr_t hpA = (uintptr_t) &A[0];
+  uintptr_t hpq = (uintptr_t) &q;
+  uintptr_t dpA, dpq;
+
+  A[0] = 1;
+  A[1] = 2;
+  A[2] = 3;
+  q = 42;
+
+  for (int i = 0; i <= omp_get_num_devices (); ++i)
+    {
+      #pragma omp target device(device_num: i) map(dpA, dpq)
+	f (&dpA, &dpq);
+      if (hpA != dpA || hpq != dpq)
+	__builtin_abort ();
+      if (A[0] != 13 || A[1] != 14 || A[2] != 15 || q != 23)
+	__builtin_abort ();
+      A[0] = 1;
+      A[1] = 2;
+      A[2] = 3;
+      q = 42;
+    }
+}
-- 
GitLab