Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prov/shm: Add unmap_region function #10364

Merged
merged 1 commit into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions prov/shm/src/smr_av.c
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,12 @@ static void smr_map_cleanup(struct smr_map *map)
{
int64_t i;

for (i = 0; i < SMR_MAX_PEERS; i++)
smr_map_del(map, i);
for (i = 0; i < SMR_MAX_PEERS; i++) {
if (map->peers[i].peer.id < 0)
continue;

smr_map_del(map, i);
}
ofi_rbmap_cleanup(&map->rbmap);
}

Expand Down Expand Up @@ -210,7 +213,6 @@ static int smr_av_remove(struct fid_av *av_fid, fi_addr_t *fi_addr, size_t count
dlist_foreach(&util_av->ep_list, av_entry) {
util_ep = container_of(av_entry, struct util_ep, av_entry);
smr_ep = container_of(util_ep, struct smr_ep, util_ep);
smr_unmap_from_endpoint(smr_ep->region, id);
if (smr_av->smr_map.num_peers > 0)
smr_ep->region->max_sar_buf_per_peer =
SMR_MAX_PEERS /
Expand Down
2 changes: 2 additions & 0 deletions prov/shm/src/smr_ep.c
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,9 @@ int64_t smr_verify_peer(struct smr_ep *ep, fi_addr_t fi_addr)
return id;

if (!ep->region->map->peers[id].region) {
ofi_spin_lock(&ep->region->map->lock);
ret = smr_map_to_region(&smr_prov, ep->region->map, id);
ofi_spin_unlock(&ep->region->map->lock);
if (ret)
return -1;
}
Expand Down
11 changes: 5 additions & 6 deletions prov/shm/src/smr_progress.c
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,9 @@ static void smr_progress_connreq(struct smr_ep *ep, struct smr_cmd *cmd)

peer_smr = smr_peer_region(ep->region, idx);
if (!peer_smr) {
ofi_spin_lock(&ep->region->map->lock);
ret = smr_map_to_region(&smr_prov, ep->region->map, idx);
ofi_spin_unlock(&ep->region->map->lock);
if (ret) {
FI_WARN(&smr_prov, FI_LOG_EP_CTRL,
"Could not map peer region\n");
Expand All @@ -891,14 +893,11 @@ static void smr_progress_connreq(struct smr_ep *ep, struct smr_cmd *cmd)
if (peer_smr->pid != (int) cmd->msg.hdr.data) {
/* TODO track and update/complete in error any transfers
* to or from old mapping
*
* TODO create smr_unmap_region
* this needs to close peer_smr->map->peers[idx].pid_fd
* This case will also return an unmapped region because the idx
* is valid but the region was unmapped
*/
munmap(peer_smr, peer_smr->total_size);
ofi_spin_lock(&ep->region->map->lock);
smr_unmap_region(&smr_prov, ep->region->map, idx, false);
smr_map_to_region(&smr_prov, ep->region->map, idx);
ofi_spin_unlock(&ep->region->map->lock);
peer_smr = smr_peer_region(ep->region, idx);
}

Expand Down
98 changes: 66 additions & 32 deletions prov/shm/src/smr_util.c
Original file line number Diff line number Diff line change
Expand Up @@ -367,16 +367,15 @@ int smr_map_to_region(const struct fi_provider *prov, struct smr_map *map,
}
pthread_mutex_unlock(&ep_list_lock);

ofi_spin_lock(&map->lock);
if (peer_buf->region)
goto unlock;
return FI_SUCCESS;
zachdworkin marked this conversation as resolved.
Show resolved Hide resolved

assert(ofi_spin_held(&map->lock));
fd = shm_open(name, O_RDWR, S_IRUSR | S_IWUSR);
if (fd < 0) {
ret = -errno;
FI_WARN_ONCE(prov, FI_LOG_AV,
"shm_open error: name %s errno %d\n", name, errno);
goto unlock;
return -errno;
}

memset(tmp, 0, sizeof(tmp));
Expand Down Expand Up @@ -437,8 +436,6 @@ int smr_map_to_region(const struct fi_provider *prov, struct smr_map *map,

out:
close(fd);
unlock:
ofi_spin_unlock(&map->lock);
return ret;
}

Expand All @@ -448,6 +445,7 @@ void smr_map_to_endpoint(struct smr_region *region, int64_t id)
struct smr_region *peer_smr;
struct smr_peer_data *local_peers;

assert(ofi_spin_held(&region->map->lock));
peer_smr = smr_peer_region(region, id);
if (region->map->peers[id].peer.id < 0 || !peer_smr)
return;
Expand Down Expand Up @@ -479,32 +477,81 @@ void smr_map_to_endpoint(struct smr_region *region, int64_t id)
return;
}

/*
 * Detach the peer at peer_id from every endpoint on the owning AV's endpoint
 * list and, unless the region is local, release the peer's mapping.
 *
 * prov:    provider handle used for warning messages.
 * map:     peer map; the caller must hold map->lock.
 * peer_id: index into map->peers of the peer to unmap.
 * local:   true when the peer region is owned by this process. The endpoint
 *          detach still runs, but the mapping itself is left in place.
 *
 * Does nothing when the peer has no mapped region.
 */
void smr_unmap_region(const struct fi_provider *prov, struct smr_map *map,
		      int64_t peer_id, bool local)
{
	struct smr_region *peer_region;
	struct smr_peer *peer;
	struct util_ep *util_ep;
	struct smr_ep *smr_ep;
	struct smr_av *av;
	int ret;

	assert(ofi_spin_held(&map->lock));
	peer = &map->peers[peer_id];
	peer_region = peer->region;
	if (!peer_region)
		return;

	/* The map is embedded in its AV; walk every endpoint on that AV. */
	av = container_of(map, struct smr_av, smr_map);
	dlist_foreach_container(&av->util_av.ep_list, struct util_ep, util_ep,
				av_entry) {
		smr_ep = container_of(util_ep, struct smr_ep, util_ep);
		smr_unmap_from_endpoint(smr_ep->region, peer_id);
	}

	/* Don't unmap memory owned by this pid because the endpoint it belongs
	 * to might still be active.
	 */
	if (local)
		return;

	if (map->flags & SMR_FLAG_HMEM_ENABLED) {
		ret = ofi_hmem_host_unregister(peer_region);
		if (ret)
			FI_WARN(prov, FI_LOG_EP_CTRL,
				"unable to unregister shm with iface\n");

		if (peer->pid_fd != -1) {
			close(peer->pid_fd);
			peer->pid_fd = -1;
		}
	}

	/* Report a failed unmap instead of dropping the result; the slot is
	 * cleared either way so the stale pointer is never reused.
	 */
	ret = munmap(peer_region, peer_region->total_size);
	if (ret)
		FI_WARN(prov, FI_LOG_EP_CTRL,
			"unable to unmap peer region\n");
	peer->region = NULL;
}

/*
 * Remove the link between this region and the peer stored at index id.
 *
 * Returns without doing anything when the map entry for id has a negative
 * peer id. Otherwise the peer's mapped region must be available (asserted).
 * The entry that the peer keeps for this region has its address id set to -1
 * and name_sent cleared, and the xpmem handle held in this region's peer data
 * at that same index is released.
 */
void smr_unmap_from_endpoint(struct smr_region *region, int64_t id)
{
	struct smr_region *peer_smr;
	struct smr_peer_data *local_peers, *peer_peers;
	int64_t peer_id;

	if (region->map->peers[id].peer.id < 0)
		return;

	peer_smr = smr_peer_region(region, id);
	assert(peer_smr);

	local_peers = smr_peer_data(region);
	peer_peers = smr_peer_data(peer_smr);
	/* Index under which the peer tracks this region. */
	peer_id = local_peers[id].addr.id;

	peer_peers[peer_id].addr.id = -1;
	peer_peers[peer_id].name_sent = 0;

	ofi_xpmem_release(&local_peers[peer_id].xpmem);
}

/*
 * Call smr_map_to_endpoint() for every peer slot from 0 up to
 * SMR_MAX_PEERS - 1, holding the region's map lock for the whole pass.
 */
void smr_exchange_all_peers(struct smr_region *region)
{
	int64_t peer_idx = 0;

	ofi_spin_lock(&region->map->lock);
	while (peer_idx < SMR_MAX_PEERS) {
		smr_map_to_endpoint(region, peer_idx);
		peer_idx++;
	}
	ofi_spin_unlock(&region->map->lock);
}

int smr_map_add(const struct fi_provider *prov, struct smr_map *map,
Expand Down Expand Up @@ -546,37 +593,24 @@ int smr_map_add(const struct fi_provider *prov, struct smr_map *map,

/*
 * Remove the peer at index id from the map.
 *
 * The peer's region is unmapped through smr_unmap_region(), the slot's
 * fi_addr and peer id are invalidated, the peer count is decremented and the
 * peer name is deleted from the name tree. The map lock is taken internally.
 *
 * id must be in the range [0, SMR_MAX_PEERS).
 */
void smr_map_del(struct smr_map *map, int64_t id)
{
	struct smr_ep_name *name;
	bool local = false;

	assert(id >= 0 && id < SMR_MAX_PEERS);

	/* A peer whose name equals an entry of ep_name_list is treated as
	 * local, which keeps smr_unmap_region() from unmapping its memory.
	 * strcmp() returns 0 on equality, so the match test must be negated.
	 * NOTE(review): confirm ep_name_list names and peer.name use the same
	 * prefix form (see smr_no_prefix()).
	 */
	pthread_mutex_lock(&ep_list_lock);
	dlist_foreach_container(&ep_name_list, struct smr_ep_name, name, entry) {
		if (!strcmp(name->name, map->peers[id].peer.name)) {
			local = true;
			break;
		}
	}
	pthread_mutex_unlock(&ep_list_lock);

	ofi_spin_lock(&map->lock);
	/* Unmap while peer.id is still valid: smr_unmap_from_endpoint()
	 * returns early once the id is negative.
	 */
	smr_unmap_region(&smr_prov, map, id, local);
	map->peers[id].fiaddr = FI_ADDR_NOTAVAIL;
	map->peers[id].peer.id = -1;
	map->num_peers--;
	ofi_rbmap_find_delete(&map->rbmap, map->peers[id].peer.name);
	ofi_spin_unlock(&map->lock);
}

Expand Down
2 changes: 2 additions & 0 deletions prov/shm/src/smr_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,8 @@ void smr_cleanup(void);
int smr_map_to_region(const struct fi_provider *prov, struct smr_map *map,
int64_t id);
void smr_map_to_endpoint(struct smr_region *region, int64_t id);
/*
 * Unmap the peer stored at index peer_id of map. The caller must hold
 * map->lock. When local is true the peer is detached from the endpoints but
 * its mapping is kept, because the memory is owned by this process.
 */
void smr_unmap_region(const struct fi_provider *prov, struct smr_map *map,
		      int64_t peer_id, bool local);
void smr_unmap_from_endpoint(struct smr_region *region, int64_t id);
void smr_exchange_all_peers(struct smr_region *region);
int smr_map_add(const struct fi_provider *prov, struct smr_map *map,
Expand Down